package org.deeplearning4j.optimize;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.ConjugateGradient;
import org.deeplearning4j.optimize.solvers.LBFGS;
import org.deeplearning4j.optimize.solvers.LineGradientDescent;
import org.deeplearning4j.optimize.solvers.StochasticGradientDescent;
import org.deeplearning4j.optimize.stepfunctions.StepFunctions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/optimize/Solver.class */
/**
 * Drives optimization of a {@link Model} using the {@link ConvexOptimizer} selected by
 * {@link NeuralNetConfiguration#getOptimizationAlgo()}.
 *
 * <p>The optimizer is created lazily on the first call to {@link #optimize(LayerWorkspaceMgr)},
 * {@link #initOptimizer()} or {@link #getOptimizer()}, and is then cached for the lifetime of
 * this instance. Instances are normally obtained through {@link Builder}.
 *
 * <p>NOTE(review): no synchronization is performed here; presumably a Solver is confined to a
 * single training thread — confirm before sharing an instance across threads.
 */
public class Solver {
    private NeuralNetConfiguration conf;
    private Collection<TrainingListener> listeners;
    private Model model;
    /** Created lazily by {@link #getOptimizer()}; {@code null} until first use. */
    private ConvexOptimizer optimizer;
    private StepFunction stepFunction;

    /** Fluent builder for {@link Solver}. */
    public static class Builder {
        private NeuralNetConfiguration conf;
        private Model model;
        private final List<TrainingListener> listeners = new ArrayList<>();

        /**
         * Sets the configuration that supplies the step function and the optimization algorithm.
         * Must be called with a non-null value before {@link #build()}.
         *
         * @param neuralNetConfiguration the network configuration
         * @return this builder
         */
        public Builder configure(NeuralNetConfiguration neuralNetConfiguration) {
            this.conf = neuralNetConfiguration;
            return this;
        }

        /**
         * Appends the given listeners. A {@code null} array is ignored.
         *
         * @param trainingListenerArr listeners to append
         * @return this builder
         */
        public Builder listener(TrainingListener... trainingListenerArr) {
            if (trainingListenerArr != null) {
                this.listeners.addAll(Arrays.asList(trainingListenerArr));
            }
            return this;
        }

        /**
         * Appends every listener in {@code collection}. A {@code null} collection is ignored.
         *
         * @param collection listeners to append
         * @return this builder
         */
        public Builder listeners(Collection<? extends TrainingListener> collection) {
            if (collection != null) {
                this.listeners.addAll(collection);
            }
            return this;
        }

        /**
         * Sets the model to be optimized.
         *
         * @param model the model
         * @return this builder
         */
        public Builder model(Model model) {
            this.model = model;
            return this;
        }

        /**
         * Creates a {@link Solver} from the current builder state. The step function is built
         * from {@code conf.getStepFunction()}; the optimizer itself is not created here.
         *
         * @return a new Solver
         * @throws NullPointerException if no configuration was supplied via {@link #configure}
         */
        public Solver build() {
            Objects.requireNonNull(this.conf, "conf must be set via configure(...) before build()");
            Solver solver = new Solver();
            solver.conf = this.conf;
            solver.stepFunction = StepFunctions.createStepFunction(this.conf.getStepFunction());
            solver.model = this.model;
            // Copy so that further calls on this builder cannot mutate the built Solver's listeners.
            solver.listeners = new ArrayList<>(this.listeners);
            return solver;
        }
    }

    /**
     * Ensures the optimizer exists, then delegates to its {@code optimize} method.
     *
     * @param layerWorkspaceMgr workspace manager forwarded unchanged to the optimizer
     */
    public void optimize(LayerWorkspaceMgr layerWorkspaceMgr) {
        initOptimizer();
        this.optimizer.optimize(layerWorkspaceMgr);
    }

    /**
     * Creates the optimizer if it does not exist yet. Creation happens inside the scope returned
     * by {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, presumably so that arrays the
     * optimizer allocates are not placed in whatever workspace the caller currently has open.
     */
    public void initOptimizer() {
        if (this.optimizer == null) {
            // try-with-resources skips close() for a null resource, matching the original null checks.
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.optimizer = getOptimizer();
            }
        }
    }

    /**
     * Returns the cached optimizer, creating it on first call according to
     * {@code conf.getOptimizationAlgo()}.
     *
     * @return the optimizer for this solver
     * @throws IllegalStateException if the configured algorithm has no matching optimizer
     */
    public ConvexOptimizer getOptimizer() {
        if (this.optimizer != null) {
            return this.optimizer;
        }
        switch (this.conf.getOptimizationAlgo()) {
            case LBFGS:
                this.optimizer = new LBFGS(this.conf, this.stepFunction, this.listeners, this.model);
                break;
            case LINE_GRADIENT_DESCENT:
                this.optimizer = new LineGradientDescent(this.conf, this.stepFunction, this.listeners, this.model);
                break;
            case CONJUGATE_GRADIENT:
                this.optimizer = new ConjugateGradient(this.conf, this.stepFunction, this.listeners, this.model);
                break;
            case STOCHASTIC_GRADIENT_DESCENT:
                this.optimizer = new StochasticGradientDescent(this.conf, this.stepFunction, this.listeners, this.model);
                break;
            default:
                throw new IllegalStateException("No optimizer found");
        }
        return this.optimizer;
    }

    /**
     * Replaces the listener collection and forwards it to the optimizer if one already exists.
     * The collection is stored by reference, not copied.
     *
     * @param collection the new listeners
     */
    public void setListeners(Collection<TrainingListener> collection) {
        this.listeners = collection;
        if (this.optimizer != null) {
            this.optimizer.setListeners(collection);
        }
    }
}
