/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.BackTrackLineSearch;
import org.deeplearning4j.optimize.solvers.StochasticGradientDescent;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeGradientStepFunction;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Skeleton implementation of {@link ConvexOptimizer}.
 *
 * <p>One call to {@link #optimize(LayerWorkspaceMgr)} computes the gradient and score of the
 * wrapped {@link Model}, runs a {@link BackTrackLineSearch} to obtain a step size, applies the
 * step through the configured {@link StepFunction}, recomputes the gradient, notifies the
 * registered {@link TrainingListener}s and finally bumps the model's iteration counter.
 *
 * <p>Subclasses customise the procedure through the {@link #preProcessLine()} and
 * {@link #postStep(INDArray)} hooks, both of which are no-ops here.
 */
public abstract class BaseOptimizer implements ConvexOptimizer {
    /** Configuration handed to the constructor; supplies the line-search iteration limit. */
    protected NeuralNetConfiguration conf;
    protected static final Logger log = LoggerFactory.getLogger(BaseOptimizer.class);
    /** Applies a step of a given size along the search direction to the parameters. */
    protected StepFunction stepFunction;
    /** Listeners notified after gradient calculation and after every completed iteration. */
    protected Collection<TrainingListener> trainingListeners = new ArrayList<TrainingListener>();
    /** The model being optimised. */
    protected Model model;
    /** Line search used by {@link #optimize(LayerWorkspaceMgr)} to determine the step size. */
    protected BackTrackLineSearch lineMaximizer;
    /** Updater used for every model that is not a {@link ComputationGraph}; created lazily. */
    protected Updater updater;
    /** Updater used when the model is a {@link ComputationGraph}; created lazily. */
    protected ComputationGraphUpdater computationGraphUpdater;
    /** Step size produced by the most recent line search (0.0 when the search was invalid). */
    protected double step;
    private int batchSize;
    /** Score recorded by the most recent {@link #gradientAndScore(LayerWorkspaceMgr)} call. */
    protected double score;
    /** Value {@link #score} held before the most recent gradient computation. */
    protected double oldScore;
    /** Maximum step passed to the line search at construction time. */
    protected double stepMax = Double.MAX_VALUE;
    /** {@link #searchState} key under which the current gradient is stored. */
    public static final String GRADIENT_KEY = "g";
    /** {@link #searchState} key under which the score is stored. */
    public static final String SCORE_KEY = "score";
    /** {@link #searchState} key under which the working copy of the parameters is stored. */
    public static final String PARAMS_KEY = "params";
    /**
     * {@link #searchState} key from which the search direction is read. Nothing in this class
     * writes it; presumably subclasses populate it in their hooks (to be confirmed per subclass).
     */
    public static final String SEARCH_DIR = "searchDirection";
    /** State carried between {@link #optimize(LayerWorkspaceMgr)} calls, keyed by the constants above. */
    protected Map<String, Object> searchState = new ConcurrentHashMap<String, Object>();
    protected GradientsAccumulator accumulator;

    /**
     * Creates the optimizer and its backing line search.
     *
     * @param conf              configuration; must be non-null because the line-search iteration
     *                          limit is read from it
     * @param stepFunction      step function to use, or {@code null} to pick the default for the
     *                          concrete optimizer class
     * @param trainingListeners listeners to notify, or {@code null} for none
     * @param model             model to optimise
     */
    public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<TrainingListener> trainingListeners, Model model) {
        this.conf = conf;
        this.stepFunction = stepFunction != null ? stepFunction : BaseOptimizer.getDefaultStepFunctionForOptimizer(this.getClass());
        // Explicit type argument instead of a raw ArrayList so the assignment stays type-checked.
        this.trainingListeners = trainingListeners != null ? trainingListeners : new ArrayList<TrainingListener>();
        this.model = model;
        this.lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
        this.lineMaximizer.setStepMax(this.stepMax);
        this.lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
    }

    @Override
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        this.accumulator = accumulator;
    }

    @Override
    public GradientsAccumulator getGradientsAccumulator() {
        return this.accumulator;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double score() {
        throw new UnsupportedOperationException("Not yet reimplemented");
    }

    /** Returns the updater, creating it first if none exists yet. */
    @Override
    public Updater getUpdater() {
        return this.getUpdater(true);
    }

    /**
     * Returns the updater.
     *
     * @param initializeIfReq when {@code true} and no updater exists, one is created for the model
     * @return the updater, or {@code null} if none exists and initialisation was not requested
     */
    @Override
    public Updater getUpdater(boolean initializeIfReq) {
        if (this.updater == null && initializeIfReq) {
            this.updater = UpdaterCreator.getUpdater(this.model);
        }
        return this.updater;
    }

    @Override
    public void setUpdater(Updater updater) {
        this.updater = updater;
    }

    /** Returns the computation-graph updater, creating it first if possible. */
    @Override
    public ComputationGraphUpdater getComputationGraphUpdater() {
        return this.getComputationGraphUpdater(true);
    }

    /**
     * Returns the computation-graph updater.
     *
     * @param initializIfReq when {@code true}, no updater exists and the model is a
     *                       {@link ComputationGraph}, a new updater is created
     * @return the updater, or {@code null} if none exists and none was created
     */
    @Override
    public ComputationGraphUpdater getComputationGraphUpdater(boolean initializIfReq) {
        if (this.computationGraphUpdater == null && this.model instanceof ComputationGraph && initializIfReq) {
            this.computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) this.model);
        }
        return this.computationGraphUpdater;
    }

    @Override
    public void setUpdaterComputationGraph(ComputationGraphUpdater updater) {
        this.computationGraphUpdater = updater;
    }

    /**
     * Replaces the registered listeners.
     *
     * @param listeners new listeners; {@code null} is stored as an empty (immutable) list
     */
    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        if (listeners == null) {
            this.trainingListeners = Collections.<TrainingListener>emptyList();
        } else {
            this.trainingListeners = listeners;
        }
    }

    @Override
    public NeuralNetConfiguration getConf() {
        return this.conf;
    }

    /**
     * Computes gradient and score on the model, notifies listeners of the gradient calculation,
     * records the score and passes the gradient through the updater.
     *
     * @param workspaceMgr workspace manager forwarded to the model and the updater
     * @return the gradient/score pair reported by the model
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        this.oldScore = this.score;
        this.model.computeGradientAndScore(workspaceMgr);
        if (this.trainingListeners != null && !this.trainingListeners.isEmpty()) {
            // Listeners are invoked outside of any currently open workspace.
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (TrainingListener l : this.trainingListeners) {
                    l.onGradientCalculation(this.model);
                }
            }
        }
        Pair<Gradient, Double> pair = this.model.gradientAndScore();
        this.score = pair.getSecond();
        this.updateGradientAccordingToParams(pair.getFirst(), this.model, this.model.batchSize(), workspaceMgr);
        return pair;
    }

    /**
     * Runs a single optimisation iteration.
     *
     * @param workspaceMgr workspace manager forwarded to gradient computation and line search
     * @return always {@code true}
     */
    @Override
    public boolean optimize(LayerWorkspaceMgr workspaceMgr) {
        Pair<Gradient, Double> pair = this.gradientAndScore(workspaceMgr);

        // The emptiness check must happen before the gradient is stored, otherwise the
        // one-time search-state setup would never run.
        boolean firstIteration = this.searchState.isEmpty();
        this.searchState.put(GRADIENT_KEY, pair.getFirst().gradient());
        if (firstIteration) {
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                this.setupSearchState(pair);
            }
        }

        // Hook for subclasses, run outside of any open workspace.
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.preProcessLine();
        }

        INDArray gradient = (INDArray) this.searchState.get(GRADIENT_KEY);
        INDArray searchDirection = (INDArray) this.searchState.get(SEARCH_DIR);
        INDArray parameters = (INDArray) this.searchState.get(PARAMS_KEY);

        try {
            this.step = this.lineMaximizer.optimize(parameters, gradient, searchDirection, workspaceMgr);
        } catch (InvalidStepException e) {
            // An invalid step is not fatal: treat it as "no step" and carry on.
            log.warn("Invalid step...continuing another iteration: {}", e.getMessage());
            this.step = 0.0;
        }

        if (this.step != 0.0) {
            this.stepFunction.step(parameters, searchDirection, this.step);
            this.model.setParams(parameters);
        } else {
            log.debug("Step size returned by line search is 0.0.");
        }

        pair = this.gradientAndScore(workspaceMgr);

        // Hook for subclasses, run outside of any open workspace.
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.postStep(pair.getFirst().gradient());
        }

        int iterationCount = BaseOptimizer.getIterationCount(this.model);
        int epochCount = BaseOptimizer.getEpochCount(this.model);
        // Same guard as in gradientAndScore: the field is protected and may have been nulled.
        if (this.trainingListeners != null && !this.trainingListeners.isEmpty()) {
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (TrainingListener listener : this.trainingListeners) {
                    listener.iterationDone(this.model, iterationCount, epochCount);
                }
            }
        }

        BaseOptimizer.incrementIterationCount(this.model, 1);
        BaseOptimizer.applyConstraints(this.model);
        return true;
    }

    /**
     * Empty extension point; not invoked anywhere in this class.
     *
     * @param gradient gradient made available to an overriding implementation
     */
    protected void postFirstStep(INDArray gradient) {
    }

    @Override
    public int batchSize() {
        return this.batchSize;
    }

    @Override
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Hook invoked by {@link #optimize(LayerWorkspaceMgr)} before the line search; no-op here. */
    @Override
    public void preProcessLine() {
    }

    /**
     * Hook invoked by {@link #optimize(LayerWorkspaceMgr)} after the step has been applied and
     * the gradient recomputed; no-op here.
     *
     * @param gradient the freshly computed gradient
     */
    @Override
    public void postStep(INDArray gradient) {
    }

    /**
     * Passes the gradient through the updater matching the model type, creating that updater
     * (outside of any open workspace) on first use.
     *
     * @param gradient     gradient to hand to the updater
     * @param model        model the gradient belongs to; anything that is not a
     *                     {@link ComputationGraph} is cast to {@link Layer}
     * @param batchSize    batch size forwarded to the updater
     * @param workspaceMgr workspace manager forwarded to the updater
     */
    @Override
    public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize, LayerWorkspaceMgr workspaceMgr) {
        int iteration = BaseOptimizer.getIterationCount(model);
        int epoch = BaseOptimizer.getEpochCount(model);
        if (model instanceof ComputationGraph) {
            if (this.computationGraphUpdater == null) {
                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
                }
            }
            this.computationGraphUpdater.update(gradient, iteration, epoch, batchSize, workspaceMgr);
            return;
        }
        if (this.updater == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.updater = UpdaterCreator.getUpdater(model);
            }
        }
        this.updater.update((Layer) model, gradient, iteration, epoch, batchSize, workspaceMgr);
    }

    /**
     * Seeds {@link #searchState} with the gradient, the score and a copy of the model parameters.
     *
     * @param pair gradient/score pair to take the initial values from
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        INDArray gradient = pair.getFirst().gradient(this.conf.variables());
        // dup(): the search state works on its own copy, not on the model's live parameters.
        INDArray params = this.model.params().dup();
        this.searchState.put(GRADIENT_KEY, gradient);
        this.searchState.put(SCORE_KEY, pair.getSecond());
        this.searchState.put(PARAMS_KEY, params);
    }

    /**
     * Chooses the default step function for an optimizer class.
     *
     * @param optimizerClass concrete optimizer class
     * @return a {@link NegativeGradientStepFunction} for exactly {@link StochasticGradientDescent},
     *         otherwise a {@link NegativeDefaultStepFunction}
     */
    public static StepFunction getDefaultStepFunctionForOptimizer(Class<? extends ConvexOptimizer> optimizerClass) {
        if (optimizerClass == StochasticGradientDescent.class) {
            return new NegativeGradientStepFunction();
        }
        return new NegativeDefaultStepFunction();
    }

    /**
     * Reads the iteration counter from the configuration appropriate for the model type.
     *
     * @param model model to inspect
     * @return the current iteration count
     */
    public static int getIterationCount(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getIterationCount();
        }
        return model.conf().getIterationCount();
    }

    /**
     * Adds {@code incrementBy} to the iteration counter held by the model's configuration.
     *
     * @param model       model whose counter is updated
     * @param incrementBy amount to add
     */
    public static void incrementIterationCount(Model model, int incrementBy) {
        if (model instanceof MultiLayerNetwork) {
            MultiLayerConfiguration conf = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
            conf.setIterationCount(conf.getIterationCount() + incrementBy);
        } else if (model instanceof ComputationGraph) {
            ComputationGraphConfiguration conf = ((ComputationGraph) model).getConfiguration();
            conf.setIterationCount(conf.getIterationCount() + incrementBy);
        } else {
            model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy);
        }
    }

    /**
     * Reads the epoch counter from the configuration appropriate for the model type.
     *
     * @param model model to inspect
     * @return the current epoch count
     */
    public static int getEpochCount(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getEpochCount();
        }
        return model.conf().getEpochCount();
    }

    /**
     * Asks the model to apply its constraints for the current iteration and epoch.
     *
     * @param model model to constrain
     */
    public static void applyConstraints(Model model) {
        int iter = BaseOptimizer.getIterationCount(model);
        int epoch = BaseOptimizer.getEpochCount(model);
        model.applyConstraints(iter, epoch);
    }

    @Override
    public StepFunction getStepFunction() {
        return this.stepFunction;
    }
}

