/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.function.Consumer;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradientCheckUtil {
    /** Class-level SLF4J logger used for all gradient check diagnostics. */
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Activation function implementations treated as smooth enough for finite-difference gradient
     * checks. A layer whose activation class is not listed here only produces a warning; it is still
     * checked.
     */
    private static final List<Class<? extends IActivation>> VALID_ACTIVATION_FUNCTIONS = Arrays.asList(
            Activation.CUBE.getActivationFunction().getClass(),
            Activation.ELU.getActivationFunction().getClass(),
            Activation.IDENTITY.getActivationFunction().getClass(),
            Activation.RATIONALTANH.getActivationFunction().getClass(),
            Activation.SIGMOID.getActivationFunction().getClass(),
            Activation.SOFTMAX.getActivationFunction().getClass(),
            Activation.SOFTPLUS.getActivationFunction().getClass(),
            Activation.SOFTSIGN.getActivationFunction().getClass(),
            Activation.TANH.getActivationFunction().getClass());

    /** Static utility holder; instances are never created. */
    private GradientCheckUtil() {
        // intentionally empty: all members are static
    }

    /**
     * Turns off probability clipping in the loss function of an output layer, when applicable.
     *
     * <p>For {@link LossMCXENT} combined with a softmax activation, the softmax clip epsilon is set to
     * 0.0. For {@link LossBinaryXENT}, the clip epsilon is set to 0.0. In both cases this happens only
     * if the epsilon is currently non-zero, and a message is logged. Clipping would otherwise cause
     * spurious gradient check failures. Output layers that are neither a
     * {@code org.deeplearning4j.nn.layers.BaseOutputLayer} nor a
     * {@code org.deeplearning4j.nn.layers.LossLayer} are left unchanged.
     *
     * @param outputLayer output layer whose loss function configuration may be modified in place
     */
    private static void configureLossFnClippingIfPresent(IOutputLayer outputLayer) {
        ILossFunction lossFn;
        IActivation activationFn;
        if (outputLayer instanceof org.deeplearning4j.nn.layers.BaseOutputLayer) {
            org.deeplearning4j.nn.layers.BaseOutputLayer outLayer =
                    (org.deeplearning4j.nn.layers.BaseOutputLayer) outputLayer;
            lossFn = ((BaseOutputLayer) outLayer.layerConf()).getLossFn();
            activationFn = ((BaseLayer) outLayer.layerConf()).getActivationFn();
        } else if (outputLayer instanceof org.deeplearning4j.nn.layers.LossLayer) {
            org.deeplearning4j.nn.layers.LossLayer lossLayer =
                    (org.deeplearning4j.nn.layers.LossLayer) outputLayer;
            lossFn = ((LossLayer) lossLayer.layerConf()).getLossFn();
            activationFn = ((LossLayer) lossLayer.layerConf()).getActivationFn();
        } else {
            // Unsupported output layer type: nothing to configure
            return;
        }

        boolean softmaxClipped = lossFn instanceof LossMCXENT
                && activationFn instanceof ActivationSoftmax
                && ((LossMCXENT) lossFn).getSoftmaxClipEps() != 0.0;
        if (softmaxClipped) {
            log.info("Setting softmax clipping epsilon to 0.0 for " + lossFn.getClass() + " loss function to avoid spurious gradient check failures");
            ((LossMCXENT) lossFn).setSoftmaxClipEps(0.0);
            return;
        }

        boolean binaryClipped = lossFn instanceof LossBinaryXENT
                && ((LossBinaryXENT) lossFn).getClipEps() != 0.0;
        if (binaryClipped) {
            log.info("Setting clipping epsilon to 0.0 for " + lossFn.getClass() + " loss function to avoid spurious gradient check failures");
            ((LossBinaryXENT) lossFn).setClipEps(0.0);
        }
    }

    /**
     * Legacy gradient check entry point that delegates to {@link #checkGradients(MLNConfig)}.
     *
     * <p>Note: the {@code print} argument is not used; {@code PrintMode.FAILURES_ONLY} is always passed.
     *
     * @deprecated build an {@link MLNConfig} and call {@link #checkGradients(MLNConfig)} instead
     */
    @Deprecated
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
        MLNConfig config = new MLNConfig()
                .net(mln)
                .epsilon(epsilon)
                .maxRelError(maxRelError)
                .minAbsoluteError(minAbsoluteError)
                .print(PrintMode.FAILURES_ONLY)
                .exitOnFirstError(exitOnFirstError)
                .input(input)
                .labels(labels);
        return GradientCheckUtil.checkGradients(config);
    }

    /**
     * Legacy gradient check entry point (with masks, subset and exclusion options) that delegates to
     * {@link #checkGradients(MLNConfig)}.
     *
     * <p>When {@code rngSeedResetEachIter} is non-null, the global ND4J RNG is re-seeded with it before
     * every score evaluation. This makes stochastic layers (e.g. dropout) repeatable. The {@code print}
     * argument is not used; {@code PrintMode.FAILURES_ONLY} is always passed.
     *
     * @deprecated build an {@link MLNConfig} and call {@link #checkGradients(MLNConfig)} instead
     */
    @Deprecated
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels, INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams, final Integer rngSeedResetEachIter) {
        Consumer<MultiLayerNetwork> reseed = rngSeedResetEachIter == null
                ? null
                : net -> Nd4j.getRandom().setSeed(rngSeedResetEachIter.intValue());
        MLNConfig config = new MLNConfig()
                .net(mln)
                .epsilon(epsilon)
                .maxRelError(maxRelError)
                .minAbsoluteError(minAbsoluteError)
                .print(PrintMode.FAILURES_ONLY)
                .exitOnFirstError(exitOnFirstError)
                .input(input)
                .labels(labels)
                .inputMask(inputMask)
                .labelMask(labelMask)
                .subset(subset)
                .maxPerParam(maxPerParam)
                .excludeParams(excludeParams)
                .callEachIter(reseed);
        return GradientCheckUtil.checkGradients(config);
    }

    /**
     * Checks the backprop gradients of a {@link MultiLayerNetwork} against numerical gradients.
     *
     * <p>For each checked parameter entry p, the numerical gradient is the central difference
     * {@code (score(p + epsilon) - score(p - epsilon)) / (2 * epsilon)}. An entry fails when its relative
     * error {@code |g - n| / (|g| + |n|)} exceeds {@code maxRelError} (or is NaN) and its absolute error
     * {@code |g - n|} is not below {@code minAbsoluteError}.
     *
     * <p>Preconditions enforced before checking:
     * <ul>
     *   <li>epsilon is in (0, 0.1] and maxRelError is in (0, 0.25];</li>
     *   <li>maxPerParam is positive when subset checking is enabled;</li>
     *   <li>the network's output layer is an {@link IOutputLayer};</li>
     *   <li>the ND4J default, network and parameter data types are all DOUBLE;</li>
     *   <li>every {@link BaseLayer} uses {@link NoOp} or {@link Sgd} with lr=1.0;</li>
     *   <li>no layer has dropout unless {@code callEachIter} is set.</li>
     * </ul>
     * Layers with an activation function outside the known-smooth list only produce a warning. Clipping in
     * MCXENT/binary XENT loss functions of the output layers is disabled, which modifies the network's
     * loss configuration in place.
     *
     * <p>When {@code subset} is enabled, each parameter array of length n is sampled with stride
     * {@code max(1, n / maxPerParam)}, so arrays shorter than {@code maxPerParam} are checked in full.
     * Parameters named in {@code excludeParams} are skipped entirely.
     *
     * @param c gradient check configuration
     * @return {@code true} if no checked parameter entry failed
     * @throws IllegalArgumentException if epsilon, maxRelError or maxPerParam are out of range, or there is
     *                                  no output layer
     * @throws IllegalStateException    on a non-DOUBLE data type, an unsupported updater, unhandled dropout,
     *                                  or a NaN numerical gradient
     */
    public static boolean checkGradients(MLNConfig c) {
        if (c.epsilon <= 0.0 || c.epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (c.maxRelError <= 0.0 || c.maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + c.maxRelError);
        }
        if (c.subset && c.maxPerParam <= 0) {
            // Guards the n / maxPerParam stride computation below against division by zero
            throw new IllegalArgumentException("Invalid maxPerParam: must be > 0 when subset is enabled, got " + c.maxPerParam);
        }
        if (!(c.net.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        DataType dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }
        DataType netDataType = c.net.getLayerWiseConfigurations().getDataType();
        if (netDataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil");
        }
        if (netDataType != c.net.params().dataType()) {
            throw new IllegalStateException("Parameters datatype does not match network configuration datatype (is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE.");
        }

        // Validate per-layer settings; layerCount is the layer's index in the configuration list
        int layerCount = 0;
        for (NeuralNetConfiguration n : c.net.getLayerWiseConfigurations().getConfs()) {
            if (n.getLayer() instanceof BaseLayer) {
                BaseLayer bl = (BaseLayer) n.getLayer();
                IUpdater u = bl.getIUpdater();
                if (u instanceof Sgd) {
                    double lr = ((Sgd) u).getLearningRate();
                    if (lr != 1.0) {
                        throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" + n.getLayer().getLayerName() + "\"");
                    }
                } else if (!(u instanceof NoOp)) {
                    throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + u);
                }
                IActivation activation = bl.getActivationFn();
                if (activation != null && !VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
                    log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
                }
            }
            if (n.getLayer().getIDropout() != null && c.callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no dropout should be present during gradient checks - got dropout = " + n.getLayer().getIDropout() + " for layer " + layerCount);
            }
            layerCount++;
        }

        for (Layer l : c.net.getLayers()) {
            if (l instanceof IOutputLayer) {
                GradientCheckUtil.configureLossFnClippingIfPresent((IOutputLayer) l);
            }
        }

        c.net.setInput(c.input);
        c.net.setLabels(c.labels);
        c.net.setLayerMaskArrays(c.inputMask, c.labelMask);
        if (c.callEachIter != null) {
            c.callEachIter.accept(c.net);
        }
        c.net.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = c.net.gradientAndScore();
        // Pass the gradient through the updater. Given the NoOp / SGD(lr=1.0) restriction above, this is
        // presumably an identity transform, so the checked gradient is the raw backprop gradient.
        Updater updater = UpdaterCreator.getUpdater(c.net);
        updater.update(c.net, gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        // Assumes the paramTable iteration order matches the layout of the flattened params() vector
        INDArray params = c.net.params();
        long nParams = params.length();
        Map<String, INDArray> paramTable = c.net.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = new long[paramNames.size()]; // exclusive end offset of each parameter
        Map<String, Long> stepSizeForParam = new HashMap<>();
        long runningEnd = 0L;
        for (int p = 0; p < paramNames.size(); p++) {
            String name = paramNames.get(p);
            long len = paramTable.get(name).length();
            runningEnd += len;
            paramEnds[p] = runningEnd;
            if (c.subset) {
                // Same rule for every parameter: short arrays (len < maxPerParam) are checked in full
                stepSizeForParam.put(name, Math.max(1L, len / c.maxPerParam));
            }
        }

        if (c.print == PrintMode.ALL) {
            int layerIdx = 0;
            for (Layer l : c.net.getLayers()) {
                Set<String> s = l.paramTable().keySet();
                log.info("Layer " + layerIdx + ": " + l.getClass().getSimpleName() + " - params " + s);
                layerIdx++;
            }
        }

        if (c.excludeParams != null && !c.excludeParams.isEmpty()) {
            log.info("NOTE: parameters will be skipped due to config: {}", c.excludeParams);
        }

        DataSet ds = new DataSet(c.input, c.labels, c.inputMask, c.labelMask);
        int totalNFailures = 0;
        long nChecked = 0L;
        double maxError = 0.0;
        int currParamNameIdx = 0;
        long i = 0L;
        while (i < nParams) {
            // A loop (not a single if) so that zero-length parameters are stepped over correctly
            while (i >= paramEnds[currParamNameIdx]) {
                currParamNameIdx++;
            }
            String paramName = paramNames.get(currParamNameIdx);
            long paramEnd = paramEnds[currParamNameIdx];
            if (c.excludeParams != null && c.excludeParams.contains(paramName)) {
                i = paramEnd; // jump to the first entry of the next parameter
                continue;
            }

            double origValue = params.getDouble(i);
            params.putScalar(i, origValue + c.epsilon);
            if (c.callEachIter != null) {
                c.callEachIter.accept(c.net);
            }
            double scorePlus = c.net.score(ds, true);
            params.putScalar(i, origValue - c.epsilon);
            if (c.callEachIter != null) {
                c.callEachIter.accept(c.net);
            }
            double scoreMinus = c.net.score(ds, true);
            params.putScalar(i, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * c.epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
            }
            nChecked++;
            double backpropGradient = gradientToCheck.getDouble(i);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0; // 0/0 above would be NaN; identical zero gradients are a pass
            }
            if (relError > maxError) {
                maxError = relError;
            }

            if (relError > c.maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < c.minAbsoluteError) {
                    if (c.print == PrintMode.ALL || (c.print == PrintMode.ZEROS && absError == 0.0)) {
                        log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + c.minAbsoluteError);
                    }
                } else {
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue);
                    if (c.exitOnFirstError) {
                        return false;
                    }
                    totalNFailures++;
                }
            } else if (c.print == PrintMode.ALL) {
                log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
            }

            long step = c.subset ? stepSizeForParam.get(paramName) : 1L;
            // Never step past the end of the current parameter, so the next parameter's first entry is visited
            i = Math.min(i + step, paramEnd);
        }

        long nPass = nChecked - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + nChecked + " of " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        return totalNFailures == 0;
    }

    /**
     * Checks the backprop gradients of a {@link ComputationGraph} against numerical gradients.
     *
     * <p>For each checked parameter entry p, the numerical gradient is the central difference
     * {@code (score(p + epsilon) - score(p - epsilon)) / (2 * epsilon)}. An entry fails when its relative
     * error {@code |g - n| / (|g| + |n|)} exceeds {@code maxRelError} (or is NaN) and its absolute error
     * {@code |g - n|} is not below {@code minAbsoluteError}.
     *
     * <p>Preconditions enforced before checking:
     * <ul>
     *   <li>epsilon is in (0, 0.1] and maxRelError is in (0, 0.25];</li>
     *   <li>maxPerParam is positive when subset checking is enabled;</li>
     *   <li>the number of input and label arrays matches the graph;</li>
     *   <li>the ND4J default, network and parameter data types are all DOUBLE;</li>
     *   <li>every {@link BaseLayer} in a {@link LayerVertex} uses {@link NoOp} or {@link Sgd} with lr=1.0;</li>
     *   <li>no layer has dropout unless {@code callEachIter} is set.</li>
     * </ul>
     * Clipping in MCXENT/binary XENT loss functions of the output layers is disabled, which modifies the
     * network's loss configuration in place.
     *
     * <p>When {@code subset} is enabled, each parameter array of length n is sampled with stride
     * {@code max(1, n / maxPerParam)}. Parameters named in {@code excludeParams} are skipped entirely.
     *
     * @param c gradient check configuration
     * @return {@code true} if no checked parameter entry failed
     * @throws IllegalArgumentException if epsilon, maxRelError or maxPerParam are out of range, or the
     *                                  input/label array counts do not match the graph
     * @throws IllegalStateException    on a non-DOUBLE data type, an unsupported updater, unhandled dropout,
     *                                  or a NaN numerical gradient
     */
    public static boolean checkGradients(GraphConfig c) {
        if (c.epsilon <= 0.0 || c.epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (c.maxRelError <= 0.0 || c.maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + c.maxRelError);
        }
        if (c.subset && c.maxPerParam <= 0) {
            // Guards the n / maxPerParam stride computation below against division by zero
            throw new IllegalArgumentException("Invalid maxPerParam: must be > 0 when subset is enabled, got " + c.maxPerParam);
        }
        if (c.net.getNumInputArrays() != c.inputs.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + c.net.getNumInputArrays() + " inputs");
        }
        if (c.net.getNumOutputArrays() != c.labels.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs");
        }
        DataType dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }
        DataType netDataType = c.net.getConfiguration().getDataType();
        if (netDataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil");
        }
        if (netDataType != c.net.params().dataType()) {
            throw new IllegalStateException("Parameters datatype does not match network configuration datatype (is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE.");
        }

        // Validate per-layer settings; layerCount is the index of the LayerVertex in vertex iteration order
        int layerCount = 0;
        for (Map.Entry<String, GraphVertex> entry : c.net.getConfiguration().getVertices().entrySet()) {
            if (!(entry.getValue() instanceof LayerVertex)) {
                continue;
            }
            String vertexName = entry.getKey();
            LayerVertex lv = (LayerVertex) entry.getValue();
            if (lv.getLayerConf().getLayer() instanceof BaseLayer) {
                BaseLayer bl = (BaseLayer) lv.getLayerConf().getLayer();
                IUpdater u = bl.getIUpdater();
                if (u instanceof Sgd) {
                    double lr = ((Sgd) u).getLearningRate();
                    if (lr != 1.0) {
                        throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" + lv.getLayerConf().getLayer().getLayerName() + "\"");
                    }
                } else if (!(u instanceof NoOp)) {
                    throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + " (\"" + vertexName + "\"); got " + u);
                }
                IActivation activation = bl.getActivationFn();
                if (activation != null && !VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
                    log.warn("Layer \"" + vertexName + "\" is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
                }
            }
            if (lv.getLayerConf().getLayer().getIDropout() != null && c.callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or no dropout should be present during gradient checks - got dropout = " + lv.getLayerConf().getLayer().getIDropout() + " for layer " + layerCount + " (\"" + vertexName + "\")");
            }
            layerCount++;
        }

        for (Layer l : c.net.getLayers()) {
            if (l instanceof IOutputLayer) {
                GradientCheckUtil.configureLossFnClippingIfPresent((IOutputLayer) l);
            }
        }

        for (int in = 0; in < c.inputs.length; in++) {
            c.net.setInput(in, c.inputs[in]);
        }
        for (int out = 0; out < c.labels.length; out++) {
            c.net.setLabel(out, c.labels[out]);
        }
        c.net.setLayerMaskArrays(c.inputMask, c.labelMask);
        if (c.callEachIter != null) {
            c.callEachIter.accept(c.net);
        }
        c.net.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = c.net.gradientAndScore();
        // Pass the gradient through the updater. Given the NoOp / SGD(lr=1.0) restriction above, this is
        // presumably an identity transform, so the checked gradient is the raw backprop gradient.
        ComputationGraphUpdater updater = new ComputationGraphUpdater(c.net);
        updater.update(gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        // Assumes the paramTable iteration order matches the layout of the flattened params() vector
        INDArray params = c.net.params();
        long nParams = params.length();
        Map<String, INDArray> paramTable = c.net.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = new long[paramNames.size()]; // exclusive end offset of each parameter
        Map<String, Long> stepSizeForParam = new HashMap<>();
        long runningEnd = 0L;
        for (int p = 0; p < paramNames.size(); p++) {
            String name = paramNames.get(p);
            long len = paramTable.get(name).length();
            runningEnd += len;
            paramEnds[p] = runningEnd;
            if (c.subset) {
                stepSizeForParam.put(name, Math.max(1L, len / c.maxPerParam));
            }
        }

        if (c.excludeParams != null && !c.excludeParams.isEmpty()) {
            log.info("NOTE: parameters will be skipped due to config: {}", c.excludeParams);
        }

        org.nd4j.linalg.dataset.api.MultiDataSet mds = new MultiDataSet(c.inputs, c.labels, c.inputMask, c.labelMask);
        int totalNFailures = 0;
        long nChecked = 0L;
        double maxError = 0.0;
        int currParamNameIdx = 0;
        long i = 0L;
        while (i < nParams) {
            // A loop (not a single if) so that zero-length parameters are stepped over correctly
            while (i >= paramEnds[currParamNameIdx]) {
                currParamNameIdx++;
            }
            String paramName = paramNames.get(currParamNameIdx);
            long paramEnd = paramEnds[currParamNameIdx];
            if (c.excludeParams != null && c.excludeParams.contains(paramName)) {
                // Explicit jump: with no for-loop increment, the next parameter's first entry is not skipped
                i = paramEnd;
                continue;
            }

            double origValue = params.getDouble(i);
            params.putScalar(i, origValue + c.epsilon);
            if (c.callEachIter != null) {
                c.callEachIter.accept(c.net);
            }
            double scorePlus = c.net.score(mds, true);
            params.putScalar(i, origValue - c.epsilon);
            if (c.callEachIter != null) {
                c.callEachIter.accept(c.net);
            }
            double scoreMinus = c.net.score(mds, true);
            params.putScalar(i, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * c.epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
            }
            nChecked++;
            double backpropGradient = gradientToCheck.getDouble(i);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0; // 0/0 above would be NaN; identical zero gradients are a pass
            }
            if (relError > maxError) {
                maxError = relError;
            }

            if (relError > c.maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < c.minAbsoluteError) {
                    if (c.print == PrintMode.ALL || (c.print == PrintMode.ZEROS && absError == 0.0)) {
                        log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + c.minAbsoluteError);
                    }
                } else {
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue);
                    if (c.exitOnFirstError) {
                        return false;
                    }
                    totalNFailures++;
                }
            } else if (c.print == PrintMode.ALL) {
                log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
            }

            long step = c.subset ? stepSizeForParam.get(paramName) : 1L;
            // Never step past the end of the current parameter, so the next parameter's first entry is visited
            i = Math.min(i + step, paramEnd);
        }

        long nPass = nChecked - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + nChecked + " of " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        return totalNFailures == 0;
    }

    /**
     * Checks the backprop gradients of a single layer against central-difference numerical gradients,
     * using the layer's own score.
     *
     * <p>The global ND4J RNG is re-seeded with {@code rngSeed} before the initial gradient computation and
     * before every perturbed score evaluation, so stochastic layers produce repeatable scores. Every
     * parameter entry is checked. An entry fails when its relative error exceeds {@code maxRelError} (or is
     * NaN) and its absolute error is not below {@code minAbsoluteError}.
     *
     * @param layer            layer to check; its parameters are perturbed in place and restored
     * @param epsilon          finite-difference step, must be in (0, 0.1]
     * @param maxRelError      maximum allowed relative error, must be in (0, 0.25]
     * @param minAbsoluteError absolute errors below this value always pass
     * @param print            whether to log per-parameter results and the final summary (failures
     *                         included)
     * @param exitOnFirstError return {@code false} immediately on the first failure
     * @param input            input array set on the layer before checking
     * @param rngSeed          seed applied to the ND4J RNG before each score evaluation
     * @return {@code true} if no parameter entry failed
     * @throws IllegalArgumentException if epsilon or maxRelError is out of range
     * @throws IllegalStateException    if the ND4J default data type is not DOUBLE, or a numerical gradient
     *                                  is NaN
     */
    public static boolean checkGradientsPretrainLayer(Layer layer, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, int rngSeed) {
        LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        DataType dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }

        layer.setInput(input, LayerWorkspaceMgr.noWorkspaces());
        Nd4j.getRandom().setSeed(rngSeed);
        layer.computeGradientAndScore(mgr);
        Pair<Gradient, Double> gradAndScore = layer.gradientAndScore();
        Updater updater = UpdaterCreator.getUpdater(layer);
        updater.update(layer, gradAndScore.getFirst(), 0, 0, layer.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();

        // Assumes the paramTable iteration order matches the layout of the flattened params() vector
        INDArray params = layer.params();
        long nParams = params.length();
        Map<String, INDArray> paramTable = layer.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = new long[paramNames.size()]; // exclusive end offset of each parameter
        long runningEnd = 0L;
        for (int p = 0; p < paramEnds.length; p++) {
            runningEnd += paramTable.get(paramNames.get(p)).length();
            paramEnds[p] = runningEnd;
        }

        int totalNFailures = 0;
        double maxError = 0.0;
        int currParamNameIdx = 0;
        // long index: parameter counts are long and may exceed Integer.MAX_VALUE
        for (long i = 0L; i < nParams; i++) {
            // A loop (not a single if) so that zero-length parameters are stepped over correctly
            while (i >= paramEnds[currParamNameIdx]) {
                currParamNameIdx++;
            }
            String paramName = paramNames.get(currParamNameIdx);

            double origValue = params.getDouble(i);
            params.putScalar(i, origValue + epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore(mgr);
            double scorePlus = layer.score();
            params.putScalar(i, origValue - epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore(mgr);
            double scoreMinus = layer.score();
            params.putScalar(i, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
            }
            double backpropGradient = gradientToCheck.getDouble(i);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0; // 0/0 above would be NaN; identical zero gradients are a pass
            }
            if (relError > maxError) {
                maxError = relError;
            }

            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < minAbsoluteError) {
                    // Guarded by print, like every other per-parameter log line in this method
                    if (print) {
                        log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
                    }
                } else {
                    if (print) {
                        log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue);
                    }
                    if (exitOnFirstError) {
                        return false;
                    }
                    totalNFailures++;
                }
            } else if (print) {
                log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
            }
        }

        if (print) {
            long nPass = nParams - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    public static class GraphConfig {
        private ComputationGraph net;
        private INDArray[] inputs;
        private INDArray[] labels;
        private INDArray[] inputMask;
        private INDArray[] labelMask;
        private double epsilon = 1.0E-6;
        private double maxRelError = 0.001;
        private double minAbsoluteError = 1.0E-8;
        private PrintMode print = PrintMode.ZEROS;
        private boolean exitOnFirstError = false;
        private boolean subset;
        private int maxPerParam;
        private Set<String> excludeParams;
        private Consumer<ComputationGraph> callEachIter;

        public ComputationGraph net() {
            return this.net;
        }

        public INDArray[] inputs() {
            return this.inputs;
        }

        public INDArray[] labels() {
            return this.labels;
        }

        public INDArray[] inputMask() {
            return this.inputMask;
        }

        public INDArray[] labelMask() {
            return this.labelMask;
        }

        public double epsilon() {
            return this.epsilon;
        }

        public double maxRelError() {
            return this.maxRelError;
        }

        public double minAbsoluteError() {
            return this.minAbsoluteError;
        }

        public PrintMode print() {
            return this.print;
        }

        public boolean exitOnFirstError() {
            return this.exitOnFirstError;
        }

        public boolean subset() {
            return this.subset;
        }

        public int maxPerParam() {
            return this.maxPerParam;
        }

        public Set<String> excludeParams() {
            return this.excludeParams;
        }

        public Consumer<ComputationGraph> callEachIter() {
            return this.callEachIter;
        }

        public GraphConfig net(ComputationGraph net) {
            this.net = net;
            return this;
        }

        public GraphConfig inputs(INDArray[] inputs) {
            this.inputs = inputs;
            return this;
        }

        public GraphConfig labels(INDArray[] labels) {
            this.labels = labels;
            return this;
        }

        public GraphConfig inputMask(INDArray[] inputMask) {
            this.inputMask = inputMask;
            return this;
        }

        /**
         * Sets the label mask arrays. The array reference is stored as-is; it is not copied.
         *
         * @param labelMask the label masks to store
         * @return this config, so calls can be chained
         */
        public GraphConfig labelMask(INDArray[] labelMask) {
            this.labelMask = labelMask;
            return this;
        }

        /**
         * Sets the epsilon value. No validation is performed.
         *
         * @param epsilon the value to store
         * @return this config, so calls can be chained
         */
        public GraphConfig epsilon(double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /**
         * Sets the maximum relative error threshold. No validation is performed.
         *
         * @param maxRelError the value to store
         * @return this config, so calls can be chained
         */
        public GraphConfig maxRelError(double maxRelError) {
            this.maxRelError = maxRelError;
            return this;
        }

        /**
         * Sets the minimum absolute error value. No validation is performed.
         *
         * @param minAbsoluteError the value to store
         * @return this config, so calls can be chained
         */
        public GraphConfig minAbsoluteError(double minAbsoluteError) {
            this.minAbsoluteError = minAbsoluteError;
            return this;
        }

        /**
         * Sets the output mode.
         *
         * @param print the mode to store (may be {@code null}; no check is made)
         * @return this config, so calls can be chained
         */
        public GraphConfig print(PrintMode print) {
            this.print = print;
            return this;
        }

        /**
         * Sets the exit-on-first-error flag.
         *
         * @param exitOnFirstError the flag to store
         * @return this config, so calls can be chained
         */
        public GraphConfig exitOnFirstError(boolean exitOnFirstError) {
            this.exitOnFirstError = exitOnFirstError;
            return this;
        }

        /**
         * Sets the subset flag.
         *
         * @param subset the flag to store
         * @return this config, so calls can be chained
         */
        public GraphConfig subset(boolean subset) {
            this.subset = subset;
            return this;
        }

        /**
         * Sets the per-parameter limit. No validation is performed.
         *
         * @param maxPerParam the limit to store
         * @return this config, so calls can be chained
         */
        public GraphConfig maxPerParam(int maxPerParam) {
            this.maxPerParam = maxPerParam;
            return this;
        }

        /**
         * Sets the parameter-exclusion set. The set reference is stored as-is; it is not copied.
         *
         * @param excludeParams the set to store
         * @return this config, so calls can be chained
         */
        public GraphConfig excludeParams(Set<String> excludeParams) {
            this.excludeParams = excludeParams;
            return this;
        }

        /**
         * Sets the callback that receives the graph.
         *
         * @param callEachIter the callback to store (may be {@code null})
         * @return this config, so calls can be chained
         */
        public GraphConfig callEachIter(Consumer<ComputationGraph> callEachIter) {
            this.callEachIter = callEachIter;
            return this;
        }

        /**
         * Compares every configuration field, read through the getters. Arrays are compared
         * element-wise with {@link Arrays#deepEquals}; other references use their own
         * {@code equals}, with two {@code null}s treated as equal.
         *
         * @param o the object to compare with
         * @return {@code true} if {@code o} is a {@code GraphConfig} with equal fields
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof GraphConfig)) {
                return false;
            }
            final GraphConfig that = (GraphConfig) o;
            // Ask the other side too, so equality stays symmetric if a subclass narrows it.
            if (!that.canEqual(this)) {
                return false;
            }

            // Cheap scalar checks first; they short-circuit before the heavier comparisons.
            final boolean scalarsMatch = Double.compare(this.epsilon(), that.epsilon()) == 0
                    && Double.compare(this.maxRelError(), that.maxRelError()) == 0
                    && Double.compare(this.minAbsoluteError(), that.minAbsoluteError()) == 0
                    && this.exitOnFirstError() == that.exitOnFirstError()
                    && this.subset() == that.subset()
                    && this.maxPerParam() == that.maxPerParam();
            if (!scalarsMatch) {
                return false;
            }

            final ComputationGraph myNet = this.net();
            final ComputationGraph theirNet = that.net();
            final boolean netsMatch = (myNet == null) ? (theirNet == null) : myNet.equals(theirNet);
            if (!netsMatch) {
                return false;
            }

            final boolean arraysMatch = Arrays.deepEquals(this.inputs(), that.inputs())
                    && Arrays.deepEquals(this.labels(), that.labels())
                    && Arrays.deepEquals(this.inputMask(), that.inputMask())
                    && Arrays.deepEquals(this.labelMask(), that.labelMask());
            if (!arraysMatch) {
                return false;
            }

            final PrintMode myPrint = this.print();
            final PrintMode theirPrint = that.print();
            final boolean printsMatch = (myPrint == null) ? (theirPrint == null) : myPrint.equals(theirPrint);
            if (!printsMatch) {
                return false;
            }

            final Set<String> myExcluded = this.excludeParams();
            final Set<String> theirExcluded = that.excludeParams();
            final boolean excludedMatch = (myExcluded == null) ? (theirExcluded == null) : myExcluded.equals(theirExcluded);
            if (!excludedMatch) {
                return false;
            }

            final Consumer<ComputationGraph> myCallback = this.callEachIter();
            final Consumer<ComputationGraph> theirCallback = that.callEachIter();
            return (myCallback == null) ? (theirCallback == null) : myCallback.equals(theirCallback);
        }

        /**
         * Equality hook: {@link #equals(Object)} calls this on the other operand, so a subclass
         * can override it to refuse equality with plain {@code GraphConfig} instances.
         *
         * @param other the candidate object
         * @return {@code true} if {@code other} is a {@code GraphConfig} (or subclass)
         */
        protected boolean canEqual(Object other) {
            return GraphConfig.class.isInstance(other);
        }

        /**
         * Computes a hash over the same fields that {@link #equals(Object)} compares, read
         * through the getters. Arrays are hashed with {@link Arrays#deepHashCode}; a
         * {@code null} reference contributes 43, and booleans contribute 79 or 97.
         *
         * @return the hash code
         */
        @Override
        public int hashCode() {
            // Previously declared but unused while the literal 59 was repeated; now actually used.
            final int PRIME = 59;
            int result = 1;

            // Same folding as Double.hashCode: xor the high and low 32 bits of the IEEE bits.
            final long epsilonBits = Double.doubleToLongBits(this.epsilon());
            result = result * PRIME + (int) (epsilonBits >>> 32 ^ epsilonBits);
            final long maxRelErrorBits = Double.doubleToLongBits(this.maxRelError());
            result = result * PRIME + (int) (maxRelErrorBits >>> 32 ^ maxRelErrorBits);
            final long minAbsoluteErrorBits = Double.doubleToLongBits(this.minAbsoluteError());
            result = result * PRIME + (int) (minAbsoluteErrorBits >>> 32 ^ minAbsoluteErrorBits);

            result = result * PRIME + (this.exitOnFirstError() ? 79 : 97);
            result = result * PRIME + (this.subset() ? 79 : 97);
            result = result * PRIME + this.maxPerParam();

            final ComputationGraph graph = this.net();
            result = result * PRIME + (graph == null ? 43 : graph.hashCode());

            result = result * PRIME + Arrays.deepHashCode(this.inputs());
            result = result * PRIME + Arrays.deepHashCode(this.labels());
            result = result * PRIME + Arrays.deepHashCode(this.inputMask());
            result = result * PRIME + Arrays.deepHashCode(this.labelMask());

            final PrintMode printMode = this.print();
            result = result * PRIME + (printMode == null ? 43 : printMode.hashCode());
            final Set<String> excluded = this.excludeParams();
            result = result * PRIME + (excluded == null ? 43 : excluded.hashCode());
            final Consumer<ComputationGraph> callback = this.callEachIter();
            result = result * PRIME + (callback == null ? 43 : callback.hashCode());
            return result;
        }

        /**
         * Renders every configuration field; array fields are expanded with
         * {@link Arrays#deepToString}.
         *
         * @return a human-readable description of this config
         */
        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder();
            text.append("GradientCheckUtil.GraphConfig(net=").append(this.net());
            text.append(", inputs=").append(Arrays.deepToString(this.inputs()));
            text.append(", labels=").append(Arrays.deepToString(this.labels()));
            text.append(", inputMask=").append(Arrays.deepToString(this.inputMask()));
            text.append(", labelMask=").append(Arrays.deepToString(this.labelMask()));
            text.append(", epsilon=").append(this.epsilon());
            text.append(", maxRelError=").append(this.maxRelError());
            text.append(", minAbsoluteError=").append(this.minAbsoluteError());
            text.append(", print=").append(this.print());
            text.append(", exitOnFirstError=").append(this.exitOnFirstError());
            text.append(", subset=").append(this.subset());
            text.append(", maxPerParam=").append(this.maxPerParam());
            text.append(", excludeParams=").append(this.excludeParams());
            text.append(", callEachIter=").append(this.callEachIter());
            text.append(")");
            return text.toString();
        }
    }

    /**
     * Fluent parameter holder for a gradient check against a {@link MultiLayerNetwork}.
     *
     * <p>Every setter stores its argument directly (INDArrays and sets are not copied) and
     * returns {@code this} so calls can be chained. Unset fields keep the defaults declared
     * below: epsilon {@code 1e-6}, maxRelError {@code 1e-3}, minAbsoluteError {@code 1e-8},
     * print {@link PrintMode#ZEROS}, booleans {@code false}, maxPerParam {@code 0}, and
     * references {@code null}.
     *
     * <p>{@link #equals(Object)} and {@link #hashCode()} cover every field, read through the
     * getters. Reference fields are compared with their own {@code equals}, and two
     * {@code null}s count as equal.
     */
    public static class MLNConfig {
        /** Multiplier used to combine field hashes in {@link #hashCode()}. */
        private static final int HASH_PRIME = 59;

        private MultiLayerNetwork net;
        private INDArray input;
        private INDArray labels;
        private INDArray inputMask;
        private INDArray labelMask;
        private double epsilon = 1.0E-6;
        private double maxRelError = 0.001;
        private double minAbsoluteError = 1.0E-8;
        private PrintMode print = PrintMode.ZEROS;
        private boolean exitOnFirstError = false;
        // NOTE(review): subset/maxPerParam presumably limit how many entries per parameter
        // array are checked — confirm in the MLN check routine.
        private boolean subset;
        private int maxPerParam;
        private Set<String> excludeParams;
        private Consumer<MultiLayerNetwork> callEachIter;

        // ---- Getters: each returns the stored value (or reference) unchanged ----

        public MultiLayerNetwork net() {
            return this.net;
        }

        public INDArray input() {
            return this.input;
        }

        public INDArray labels() {
            return this.labels;
        }

        public INDArray inputMask() {
            return this.inputMask;
        }

        public INDArray labelMask() {
            return this.labelMask;
        }

        public double epsilon() {
            return this.epsilon;
        }

        public double maxRelError() {
            return this.maxRelError;
        }

        public double minAbsoluteError() {
            return this.minAbsoluteError;
        }

        public PrintMode print() {
            return this.print;
        }

        public boolean exitOnFirstError() {
            return this.exitOnFirstError;
        }

        public boolean subset() {
            return this.subset;
        }

        public int maxPerParam() {
            return this.maxPerParam;
        }

        public Set<String> excludeParams() {
            return this.excludeParams;
        }

        public Consumer<MultiLayerNetwork> callEachIter() {
            return this.callEachIter;
        }

        // ---- Fluent setters: store the argument as-is (no copy, no validation), return this ----

        public MLNConfig net(MultiLayerNetwork net) {
            this.net = net;
            return this;
        }

        public MLNConfig input(INDArray input) {
            this.input = input;
            return this;
        }

        public MLNConfig labels(INDArray labels) {
            this.labels = labels;
            return this;
        }

        public MLNConfig inputMask(INDArray inputMask) {
            this.inputMask = inputMask;
            return this;
        }

        public MLNConfig labelMask(INDArray labelMask) {
            this.labelMask = labelMask;
            return this;
        }

        public MLNConfig epsilon(double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        public MLNConfig maxRelError(double maxRelError) {
            this.maxRelError = maxRelError;
            return this;
        }

        public MLNConfig minAbsoluteError(double minAbsoluteError) {
            this.minAbsoluteError = minAbsoluteError;
            return this;
        }

        public MLNConfig print(PrintMode print) {
            this.print = print;
            return this;
        }

        public MLNConfig exitOnFirstError(boolean exitOnFirstError) {
            this.exitOnFirstError = exitOnFirstError;
            return this;
        }

        public MLNConfig subset(boolean subset) {
            this.subset = subset;
            return this;
        }

        public MLNConfig maxPerParam(int maxPerParam) {
            this.maxPerParam = maxPerParam;
            return this;
        }

        public MLNConfig excludeParams(Set<String> excludeParams) {
            this.excludeParams = excludeParams;
            return this;
        }

        public MLNConfig callEachIter(Consumer<MultiLayerNetwork> callEachIter) {
            this.callEachIter = callEachIter;
            return this;
        }

        // ---- Value semantics ----

        /**
         * Compares every field, read through the getters, in the same order as before:
         * scalars first, then the network, arrays, print mode, exclusions and callback.
         *
         * @param o the object to compare with
         * @return {@code true} if {@code o} is an {@code MLNConfig} with equal fields
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MLNConfig)) {
                return false;
            }
            MLNConfig other = (MLNConfig) o;
            // Ask the other side too, so equality stays symmetric if a subclass narrows it.
            if (!other.canEqual(this)) {
                return false;
            }
            // Cheap scalar checks short-circuit before the INDArray comparisons.
            return Double.compare(this.epsilon(), other.epsilon()) == 0
                    && Double.compare(this.maxRelError(), other.maxRelError()) == 0
                    && Double.compare(this.minAbsoluteError(), other.minAbsoluteError()) == 0
                    && this.exitOnFirstError() == other.exitOnFirstError()
                    && this.subset() == other.subset()
                    && this.maxPerParam() == other.maxPerParam()
                    && nullSafeEquals(this.net(), other.net())
                    && nullSafeEquals(this.input(), other.input())
                    && nullSafeEquals(this.labels(), other.labels())
                    && nullSafeEquals(this.inputMask(), other.inputMask())
                    && nullSafeEquals(this.labelMask(), other.labelMask())
                    // Enum constants are singletons, so identity is equality.
                    && this.print() == other.print()
                    && nullSafeEquals(this.excludeParams(), other.excludeParams())
                    && nullSafeEquals(this.callEachIter(), other.callEachIter());
        }

        /**
         * Equality hook: {@link #equals(Object)} calls this on the other operand, so a subclass
         * can override it to refuse equality with plain {@code MLNConfig} instances.
         *
         * @param other the candidate object
         * @return {@code true} if {@code other} is an {@code MLNConfig} (or subclass)
         */
        protected boolean canEqual(Object other) {
            return other instanceof MLNConfig;
        }

        /**
         * Hash over the same fields {@link #equals(Object)} compares. It produces the same
         * values as the previous hand-expanded implementation: {@code null} contributes 43,
         * and booleans contribute 79 or 97.
         *
         * @return the hash code
         */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * HASH_PRIME + doubleHash(this.epsilon());
            result = result * HASH_PRIME + doubleHash(this.maxRelError());
            result = result * HASH_PRIME + doubleHash(this.minAbsoluteError());
            result = result * HASH_PRIME + (this.exitOnFirstError() ? 79 : 97);
            result = result * HASH_PRIME + (this.subset() ? 79 : 97);
            result = result * HASH_PRIME + this.maxPerParam();
            result = result * HASH_PRIME + nullSafeHash(this.net());
            result = result * HASH_PRIME + nullSafeHash(this.input());
            result = result * HASH_PRIME + nullSafeHash(this.labels());
            result = result * HASH_PRIME + nullSafeHash(this.inputMask());
            result = result * HASH_PRIME + nullSafeHash(this.labelMask());
            result = result * HASH_PRIME + nullSafeHash(this.print());
            result = result * HASH_PRIME + nullSafeHash(this.excludeParams());
            result = result * HASH_PRIME + nullSafeHash(this.callEachIter());
            return result;
        }

        /**
         * Renders every configuration field using each value's own {@code toString}.
         *
         * @return a human-readable description of this config
         */
        @Override
        public String toString() {
            return "GradientCheckUtil.MLNConfig(net=" + this.net() + ", input=" + this.input() + ", labels=" + this.labels() + ", inputMask=" + this.inputMask() + ", labelMask=" + this.labelMask() + ", epsilon=" + this.epsilon() + ", maxRelError=" + this.maxRelError() + ", minAbsoluteError=" + this.minAbsoluteError() + ", print=" + this.print() + ", exitOnFirstError=" + this.exitOnFirstError() + ", subset=" + this.subset() + ", maxPerParam=" + this.maxPerParam() + ", excludeParams=" + this.excludeParams() + ", callEachIter=" + this.callEachIter() + ")";
        }

        /** Null-safe equality: two {@code null}s are equal; otherwise delegates to {@code a.equals(b)}. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Hash of a possibly-{@code null} value; {@code null} hashes to 43. */
        private static int nullSafeHash(Object value) {
            return value == null ? 43 : value.hashCode();
        }

        /** Folds the 64-bit IEEE representation to 32 bits (same result as {@code Double.hashCode}). */
        private static int doubleHash(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) (bits >>> 32 ^ bits);
        }
    }

    /**
     * Output verbosity options stored in {@code GraphConfig.print} / {@code MLNConfig.print}
     * (both default to {@link #ZEROS}).
     *
     * <p>NOTE(review): presumably ALL reports every parameter, ZEROS additionally highlights
     * zero-gradient entries, and FAILURES_ONLY reports only failing entries — confirm against
     * the check routines that consume it.
     */
    public enum PrintMode {
        ALL,
        ZEROS,
        FAILURES_ONLY
    }
}

