/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.schedule.ISchedule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for working with {@link MultiLayerNetwork} and {@link ComputationGraph}
 * instances: converting between the two, changing learning rates (and re-laying-out the
 * updater state accordingly), running a single-input forward pass on a generic
 * {@link Model}, and filtering lists by element type.
 */
public class NetworkUtils {
    private static final Logger log = LoggerFactory.getLogger(NetworkUtils.class);

    /** Utility class; not instantiable. */
    private NetworkUtils() {
    }

    /**
     * Builds a {@link ComputationGraph} with the same layers, parameters and updater state
     * as the given {@link MultiLayerNetwork}.
     * <p>
     * The graph has a single input named {@code "in"}; layer {@code i} of the network becomes
     * a graph layer named {@code String.valueOf(i)} whose input is the previous layer (or
     * {@code "in"} for the first one). The last layer is the graph's only output.
     *
     * @param net network to convert; its configuration is cloned before use
     * @return a newly initialised graph whose parameters are set from {@code net.params()}
     */
    public static ComputationGraph toComputationGraph(MultiLayerNetwork net) {
        ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder()
                .dataType(net.getLayerWiseConfigurations().getDataType())
                .graphBuilder();
        MultiLayerConfiguration origConf = net.getLayerWiseConfigurations().clone();

        String lastLayer = "in";
        b.addInputs("in");
        int layerIdx = 0;
        for (NeuralNetConfiguration c : origConf.getConfs()) {
            String currLayer = String.valueOf(layerIdx);
            InputPreProcessor preproc = origConf.getInputPreProcess(layerIdx);
            b.addLayer(currLayer, c.getLayer(), preproc, lastLayer);
            lastLayer = currLayer;
            ++layerIdx;
        }
        b.setOutputs(lastLayer);

        ComputationGraph cg = new ComputationGraph(b.build());
        cg.init();
        cg.setParams(net.params());

        // Carry the updater state across too, when the source network has any.
        INDArray updaterState = net.getUpdater().getStateViewArray();
        if (updaterState != null) {
            cg.getUpdater().getUpdaterStateViewArray().assign(updaterState);
        }
        return cg;
    }

    /**
     * Sets a fixed learning rate on every layer of the network, removing any schedule,
     * then refreshes the updater once.
     *
     * @param net   network to modify
     * @param newLr new learning rate
     */
    public static void setLearningRate(MultiLayerNetwork net, double newLr) {
        setLearningRate(net, newLr, null);
    }

    /**
     * Applies a learning rate or schedule to all layers, refreshing the updater only once
     * at the end instead of once per layer.
     */
    private static void setLearningRate(MultiLayerNetwork net, double newLr, ISchedule lrSchedule) {
        int nLayers = net.getnLayers();
        for (int i = 0; i < nLayers; ++i) {
            setLearningRate(net, i, newLr, lrSchedule, false);
        }
        refreshUpdater(net);
    }

    /**
     * Applies a learning rate or schedule to one layer. The updater is refreshed only when
     * {@code refreshUpdater} is true and the layer configuration is a {@link BaseLayer}.
     */
    private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
        Layer layerConf = net.getLayer(layerNumber).conf().getLayer();
        if (applyLearningRate(layerConf, newLr, newLrSchedule) && refreshUpdater) {
            refreshUpdater(net);
        }
    }

    /**
     * Discards the network's updater, lets the network create a fresh one, and installs the
     * previous updater state re-ordered to match the new updater's block layout.
     */
    private static void refreshUpdater(MultiLayerNetwork net) {
        MultiLayerUpdater origUpdater = (MultiLayerUpdater) net.getUpdater();
        INDArray origUpdaterState = origUpdater.getStateViewArray();
        // Setting null and reading the updater back yields a newly built updater.
        net.setUpdater(null);
        MultiLayerUpdater newUpdater = (MultiLayerUpdater) net.getUpdater();
        INDArray newUpdaterState = rebuildUpdaterStateArray(
                origUpdaterState, origUpdater.getUpdaterBlocks(), newUpdater.getUpdaterBlocks());
        newUpdater.setStateViewArray(newUpdaterState);
    }

    /**
     * Sets a learning rate schedule on every layer of the network, then refreshes the
     * updater once.
     *
     * @param net           network to modify
     * @param newLrSchedule schedule to install
     */
    public static void setLearningRate(MultiLayerNetwork net, ISchedule newLrSchedule) {
        setLearningRate(net, Double.NaN, newLrSchedule);
    }

    /**
     * Sets a fixed learning rate on a single layer, removing any schedule, and refreshes
     * the updater.
     *
     * @param net         network to modify
     * @param layerNumber index of the layer
     * @param newLr       new learning rate
     */
    public static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr) {
        setLearningRate(net, layerNumber, newLr, null, true);
    }

    /**
     * Sets a learning rate schedule on a single layer and refreshes the updater.
     *
     * @param net         network to modify
     * @param layerNumber index of the layer
     * @param lrSchedule  schedule to install
     */
    public static void setLearningRate(MultiLayerNetwork net, int layerNumber, ISchedule lrSchedule) {
        setLearningRate(net, layerNumber, Double.NaN, lrSchedule, true);
    }

    /**
     * Returns the learning rate of a layer at the network's current iteration and epoch.
     *
     * @param net         network to query
     * @param layerNumber index of the layer
     * @return the learning rate, or {@code null} if the layer configuration is not a
     *         {@link BaseLayer}, has no updater, the updater has no learning rate, or the
     *         reported rate is NaN
     */
    public static Double getLearningRate(MultiLayerNetwork net, int layerNumber) {
        Layer layerConf = net.getLayer(layerNumber).conf().getLayer();
        int iter = net.getIterationCount();
        int epoch = net.getEpochCount();
        return learningRateOf(layerConf, iter, epoch);
    }

    /**
     * Sets a fixed learning rate on every layer of the graph, removing any schedule,
     * then refreshes the updater once.
     *
     * @param net   graph to modify
     * @param newLr new learning rate
     */
    public static void setLearningRate(ComputationGraph net, double newLr) {
        setLearningRate(net, newLr, null);
    }

    /**
     * Applies a learning rate or schedule to all layers of the graph, refreshing the updater
     * only once at the end instead of once per layer.
     */
    private static void setLearningRate(ComputationGraph net, double newLr, ISchedule lrSchedule) {
        for (org.deeplearning4j.nn.api.Layer layer : net.getLayers()) {
            setLearningRate(net, layer.conf().getLayer().getLayerName(), newLr, lrSchedule, false);
        }
        refreshUpdater(net);
    }

    /**
     * Applies a learning rate or schedule to one named layer. The updater is refreshed only
     * when {@code refreshUpdater} is true and the layer configuration is a {@link BaseLayer}.
     */
    private static void setLearningRate(ComputationGraph net, String layerName, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {
        Layer layerConf = net.getLayer(layerName).conf().getLayer();
        if (applyLearningRate(layerConf, newLr, newLrSchedule) && refreshUpdater) {
            refreshUpdater(net);
        }
    }

    /**
     * Discards the graph's updater, lets the graph create a fresh one, and installs the
     * previous updater state re-ordered to match the new updater's block layout.
     */
    private static void refreshUpdater(ComputationGraph net) {
        ComputationGraphUpdater origUpdater = net.getUpdater();
        INDArray origUpdaterState = origUpdater.getStateViewArray();
        // Setting null and reading the updater back yields a newly built updater.
        net.setUpdater(null);
        ComputationGraphUpdater newUpdater = net.getUpdater();
        INDArray newUpdaterState = rebuildUpdaterStateArray(
                origUpdaterState, origUpdater.getUpdaterBlocks(), newUpdater.getUpdaterBlocks());
        newUpdater.setStateViewArray(newUpdaterState);
    }

    /**
     * Sets a learning rate schedule on every layer of the graph, then refreshes the
     * updater once.
     *
     * @param net           graph to modify
     * @param newLrSchedule schedule to install
     */
    public static void setLearningRate(ComputationGraph net, ISchedule newLrSchedule) {
        setLearningRate(net, Double.NaN, newLrSchedule);
    }

    /**
     * Sets a fixed learning rate on a single named layer, removing any schedule, and
     * refreshes the updater.
     *
     * @param net       graph to modify
     * @param layerName name of the layer
     * @param newLr     new learning rate
     */
    public static void setLearningRate(ComputationGraph net, String layerName, double newLr) {
        setLearningRate(net, layerName, newLr, null, true);
    }

    /**
     * Sets a learning rate schedule on a single named layer and refreshes the updater.
     *
     * @param net        graph to modify
     * @param layerName  name of the layer
     * @param lrSchedule schedule to install
     */
    public static void setLearningRate(ComputationGraph net, String layerName, ISchedule lrSchedule) {
        setLearningRate(net, layerName, Double.NaN, lrSchedule, true);
    }

    /**
     * Returns the learning rate of a named layer at the graph's current iteration and epoch
     * (both read from the graph configuration).
     *
     * @param net       graph to query
     * @param layerName name of the layer
     * @return the learning rate, or {@code null} if the layer configuration is not a
     *         {@link BaseLayer}, has no updater, the updater has no learning rate, or the
     *         reported rate is NaN
     */
    public static Double getLearningRate(ComputationGraph net, String layerName) {
        Layer layerConf = net.getLayer(layerName).conf().getLayer();
        int iter = net.getConfiguration().getIterationCount();
        int epoch = net.getConfiguration().getEpochCount();
        return learningRateOf(layerConf, iter, epoch);
    }

    /**
     * Runs a forward pass for a single input array on a supported model type.
     *
     * @param model a {@link MultiLayerNetwork} or {@link ComputationGraph}
     * @param input input array
     * @return the network output ({@code outputSingle} is used for graphs)
     * @throws UnsupportedOperationException if the model is of any other type
     */
    public static INDArray output(Model model, INDArray input) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).output(input);
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).outputSingle(input);
        }
        String className = model.getClass().getName();
        String message = className.startsWith("org.deeplearning4j")
                ? className + " models are not yet supported and pull requests are welcome: https://github.com/deeplearning4j/deeplearning4j"
                : className + " models are unsupported.";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Removes every element of {@code list} that is an instance of {@code remove}.
     *
     * @param list   list to filter in place; {@code null} or empty lists are left untouched
     * @param remove type whose instances (including subtypes) are removed
     */
    public static void removeInstances(List<?> list, Class<?> remove) {
        removeInstancesWithWarning(list, remove, null);
    }

    /**
     * Removes every element of {@code list} that is an instance of {@code remove}, logging
     * {@code warning} (if non-null) once per removed element.
     * <p>
     * {@code null} elements are never instances of any type and are therefore kept.
     *
     * @param list    list to filter in place; {@code null} or empty lists are left untouched
     * @param remove  type whose instances (including subtypes) are removed
     * @param warning message to log at WARN level for each removal, or {@code null} for none
     */
    public static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning) {
        if (list == null || list.isEmpty()) {
            return;
        }
        Iterator<?> iter = list.iterator();
        while (iter.hasNext()) {
            Object o = iter.next();
            // isInstance is null-safe, unlike o.getClass(), and matches subtypes the same way.
            if (!remove.isInstance(o)) {
                continue;
            }
            if (warning != null) {
                log.warn(warning);
            }
            iter.remove();
        }
    }

    /**
     * Re-orders an updater state array so that it matches the block layout of a new updater.
     * <p>
     * If there is no state, or both updaters group the same parameters into the same blocks,
     * the original array is returned unchanged. Otherwise the original state is cut into
     * per-parameter views and those views are concatenated in the order the new blocks
     * expect.
     *
     * @param origUpdaterState state array of the original updater; may be {@code null}
     * @param orig             blocks of the original updater
     * @param newUpdater       blocks of the new updater
     * @return {@code origUpdaterState} itself when no re-ordering is needed, otherwise a new
     *         rank-2 array of the same length
     * @throws IllegalStateException if the rebuilt array is not rank 2 or differs in length
     *                               from the original (raised by {@code Preconditions.checkState})
     */
    protected static INDArray rebuildUpdaterStateArray(INDArray origUpdaterState, List<UpdaterBlock> orig, List<UpdaterBlock> newUpdater) {
        if (origUpdaterState == null) {
            return origUpdaterState;
        }
        if (hasSameBlockLayout(orig, newUpdater)) {
            return origUpdaterState;
        }

        Map<String, List<INDArray>> stateViewsPerParam = splitStateByParam(orig);

        List<INDArray> toConcat = new ArrayList<INDArray>();
        for (UpdaterBlock block : newUpdater) {
            List<UpdaterBlock.ParamState> params = block.getLayersAndVariablesInBlock();
            List<INDArray> firstParamViews = stateViewsPerParam.get(paramKey(params.get(0)));
            if (firstParamViews == null) {
                // No state was recorded for this block's parameters (the original block had
                // a zero-sized updater view), so it contributes nothing to the new array.
                continue;
            }
            // Every parameter in a block has the same number of state components; emit them
            // component by component, each covering all parameters of the block in order.
            int nComponents = firstParamViews.size();
            for (int i = 0; i < nComponents; ++i) {
                for (UpdaterBlock.ParamState p : params) {
                    toConcat.add(stateViewsPerParam.get(paramKey(p)).get(i));
                }
            }
        }

        INDArray rebuilt = Nd4j.hstack(toConcat);
        Preconditions.checkState(rebuilt.rank() == 2, "Expected rank 2");
        Preconditions.checkState(origUpdaterState.length() == rebuilt.length(),
                "Updater state array lengths should be equal: got %s vs. %s", origUpdaterState.length(), rebuilt.length());
        return rebuilt;
    }

    /**
     * Installs either the schedule (when non-null) or the fixed rate on the layer's updater.
     *
     * @return {@code true} if the configuration is a {@link BaseLayer} (whether or not its
     *         updater accepted a learning rate), {@code false} otherwise
     */
    private static boolean applyLearningRate(Layer layerConf, double newLr, ISchedule newLrSchedule) {
        if (!(layerConf instanceof BaseLayer)) {
            return false;
        }
        IUpdater u = ((BaseLayer) layerConf).getIUpdater();
        if (u != null && u.hasLearningRate()) {
            if (newLrSchedule != null) {
                // A schedule takes precedence; the fixed rate is passed as NaN.
                u.setLrAndSchedule(Double.NaN, newLrSchedule);
            } else {
                u.setLrAndSchedule(newLr, null);
            }
        }
        return true;
    }

    /**
     * Reads the learning rate from a layer configuration's updater, or returns {@code null}
     * when none is available or the reported value is NaN.
     */
    private static Double learningRateOf(Layer layerConf, int iter, int epoch) {
        if (!(layerConf instanceof BaseLayer)) {
            return null;
        }
        IUpdater u = ((BaseLayer) layerConf).getIUpdater();
        if (u == null || !u.hasLearningRate()) {
            return null;
        }
        double lr = u.getLearningRate(iter, epoch);
        if (Double.isNaN(lr)) {
            return null;
        }
        return lr;
    }

    /** Returns true when both lists have equal size and pairwise-equal block contents. */
    private static boolean hasSameBlockLayout(List<UpdaterBlock> orig, List<UpdaterBlock> newUpdater) {
        if (orig.size() != newUpdater.size()) {
            return false;
        }
        for (int i = 0; i < orig.size(); ++i) {
            if (!orig.get(i).getLayersAndVariablesInBlock().equals(newUpdater.get(i).getLayersAndVariablesInBlock())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Cuts each block's updater view into per-parameter views, keyed by {@link #paramKey}.
     * <p>
     * A block's view is treated as {@code (view length / param count)} consecutive segments,
     * each as long as the block's parameters; within a segment the parameters appear in
     * block order. Each parameter therefore ends up with one view per segment.
     */
    private static Map<String, List<INDArray>> splitStateByParam(List<UpdaterBlock> blocks) {
        Map<String, List<INDArray>> stateViewsPerParam = new HashMap<String, List<INDArray>>();
        for (UpdaterBlock block : blocks) {
            List<UpdaterBlock.ParamState> params = block.getLayersAndVariablesInBlock();
            int blockPStart = block.getParamOffsetStart();
            int blockPEnd = block.getParamOffsetEnd();
            int blockUStart = block.getUpdaterViewOffsetStart();
            int blockUEnd = block.getUpdaterViewOffsetEnd();
            int paramsMultiplier = (blockUEnd - blockUStart) / (blockPEnd - blockPStart);
            INDArray updaterView = block.getUpdaterView();
            long nParamsInBlock = blockPEnd - blockPStart;

            long soFar = 0L;
            for (int sub = 0; sub < paramsMultiplier; ++sub) {
                INDArray segment = columns(updaterView, soFar, soFar + nParamsInBlock);
                long offsetWithinSub = 0L;
                for (UpdaterBlock.ParamState ps : params) {
                    String key = paramKey(ps);
                    long nParamsThisParam = ps.getParamView().length();
                    INDArray currSplit = columns(segment, offsetWithinSub, offsetWithinSub + nParamsThisParam);
                    List<INDArray> views = stateViewsPerParam.get(key);
                    if (views == null) {
                        views = new ArrayList<INDArray>();
                        stateViewsPerParam.put(key, views);
                    }
                    views.add(currSplit);
                    offsetWithinSub += nParamsThisParam;
                }
                soFar += nParamsInBlock;
            }
        }
        return stateViewsPerParam;
    }

    /** Returns the view of row 0, columns {@code [from, to)} of the given array. */
    private static INDArray columns(INDArray array, long from, long to) {
        return array.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to)});
    }

    /** Builds the map key for a parameter: owner id, underscore, parameter name. */
    private static String paramKey(UpdaterBlock.ParamState ps) {
        return getId(ps.getLayer()) + "_" + ps.getParamName();
    }

    /**
     * Returns the vertex index for a {@link GraphVertex}, otherwise the layer index.
     *
     * @throws ClassCastException if the trainable is neither a graph vertex nor a layer
     */
    private static int getId(Trainable trainable) {
        if (trainable instanceof GraphVertex) {
            return ((GraphVertex) trainable).getVertexIndex();
        }
        return ((org.deeplearning4j.nn.api.Layer) trainable).getIndex();
    }
}

