package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/StochasticGradientDescent.class */
/**
 * Optimizer that performs one plain stochastic-gradient-descent step per {@link #optimize} call.
 *
 * <p>Each call computes the gradient once and applies it to the model parameters. The update goes
 * through the configured {@link StepFunction}, or through the gradients accumulator when one is
 * set. The call then notifies the training listeners, increments the model's iteration count and
 * applies parameter constraints. No line search is performed ({@link #preProcessLine()} and
 * {@link #postStep(INDArray)} are no-ops).
 */
public class StochasticGradientDescent extends BaseOptimizer {
    private static final Logger log = LoggerFactory.getLogger(StochasticGradientDescent.class);

    /**
     * Creates an SGD optimizer.
     *
     * @param neuralNetConfiguration network configuration passed to {@link BaseOptimizer}
     * @param stepFunction step function used to apply the gradient to the parameters
     * @param listeners training listeners notified after every iteration
     * @param model model whose parameters are optimized
     */
    public StochasticGradientDescent(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<TrainingListener> listeners, Model model) {
        super(neuralNetConfiguration, stepFunction, listeners, model);
    }

    /**
     * Runs a single SGD iteration.
     *
     * <p>If the accumulator holds externally supplied updates, they are applied to the parameters
     * first. The gradient is then computed, flattened to a rank-1 vector and applied, either via the
     * accumulator (which also stores it) or directly via the step function. Finally, listeners are
     * notified with workspaces scoped out, the iteration count is incremented by one and
     * constraints are applied.
     *
     * @param layerWorkspaceMgr workspace manager used while computing the gradient and score
     * @return always {@code true}
     */
    @Override
    public boolean optimize(LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.accumulator != null && this.accumulator.hasAnything()) {
            log.info("Applying external updates before FF...");
            INDArray currentParams = this.model.params();
            // Scratch array shaped like the params. Its initial contents are undefined; presumably
            // the accumulator overwrites it before reading (TODO confirm).
            INDArray scratch = Nd4j.createUninitialized(currentParams.shape(), currentParams.ordering());
            this.accumulator.applyUpdate(this.stepFunction, currentParams, scratch, false);
        }

        Gradient gradient = (Gradient) gradientAndScore(layerWorkspaceMgr).getFirst();
        INDArray params = this.model.params();
        INDArray rawGradient = gradient.gradient();
        // Flatten to a rank-1 vector of all gradient elements before stepping.
        INDArray flatGradient = rawGradient.reshape(new long[] {rawGradient.length()});

        if (this.accumulator != null) {
            // Model types other than MultiLayerNetwork / ComputationGraph report 0 for both counters.
            int iteration = 0;
            int epoch = 0;
            if (this.model instanceof MultiLayerNetwork) {
                MultiLayerNetwork network = (MultiLayerNetwork) this.model;
                iteration = network.getIterationCount();
                epoch = network.getEpochCount();
            } else if (this.model instanceof ComputationGraph) {
                ComputationGraph graph = (ComputationGraph) this.model;
                iteration = graph.getIterationCount();
                epoch = graph.getEpochCount();
            }
            this.accumulator.storeUpdate(flatGradient, iteration, epoch);
            this.accumulator.applyUpdate(this.stepFunction, params, flatGradient, true);
        } else {
            this.stepFunction.step(params, flatGradient);
        }
        this.model.setParams(params);

        int iterationCount = BaseOptimizer.getIterationCount(this.model);
        int epochCount = BaseOptimizer.getEpochCount(this.model);
        // Only the listener callbacks run with workspaces scoped out. The scope is closed exactly
        // once, even if a listener throws. The previous desugared form could close it twice when
        // incrementIterationCount/applyConstraints or close() itself threw.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            for (TrainingListener listener : this.trainingListeners) {
                listener.iterationDone(this.model, iterationCount, epochCount);
            }
        }
        BaseOptimizer.incrementIterationCount(this.model, 1);
        applyConstraints(this.model);
        return true;
    }

    /** No-op: SGD performs no line search, so there is no per-line preprocessing. */
    @Override
    public void preProcessLine() {
    }

    /**
     * No-op: SGD needs no post-step processing.
     *
     * @param line ignored
     */
    @Override
    public void postStep(INDArray line) {
    }
}
