package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.LossLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.function.Consumer;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil.class */
public class GradientCheckUtil {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Implementation classes of the activation functions that are accepted without a warning.
     * The gradient-check routine logs a warning for any layer whose activation class is not
     * contained in this list.
     */
    private static final List<Class<? extends IActivation>> VALID_ACTIVATION_FUNCTIONS = Arrays.asList(
            Activation.CUBE.getActivationFunction().getClass(),
            Activation.ELU.getActivationFunction().getClass(),
            Activation.IDENTITY.getActivationFunction().getClass(),
            Activation.RATIONALTANH.getActivationFunction().getClass(),
            Activation.SIGMOID.getActivationFunction().getClass(),
            Activation.SOFTMAX.getActivationFunction().getClass(),
            Activation.SOFTPLUS.getActivationFunction().getClass(),
            Activation.SOFTSIGN.getActivationFunction().getClass(),
            Activation.TANH.getActivationFunction().getClass());

    /* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil$GraphConfig.class */
    /**
     * Fluent, mutable holder for the settings of a gradient check on a {@link ComputationGraph}.
     *
     * <p>Every setter returns {@code this} so calls can be chained. The exact meaning of each
     * setting is defined by the routine that consumes this object (not part of this class);
     * NOTE(review): presumably a {@code checkGradients(GraphConfig)} overload - confirm.
     *
     * <p>Defaults: {@code epsilon = 1e-6}, {@code maxRelError = 1e-3},
     * {@code minAbsoluteError = 1e-8}, {@code print = PrintMode.ZEROS},
     * {@code exitOnFirstError = false}; all other fields start as {@code null} / {@code false} / {@code 0}.
     */
    public static class GraphConfig {
        /** Network under test. */
        private ComputationGraph net;
        /** Input arrays fed to the network. */
        private INDArray[] inputs;
        /** Label arrays. */
        private INDArray[] labels;
        /** Optional input mask arrays (may be {@code null}). */
        private INDArray[] inputMask;
        /** Optional label mask arrays (may be {@code null}). */
        private INDArray[] labelMask;
        private double epsilon = 1.0E-6d;
        private double maxRelError = 0.001d;
        private double minAbsoluteError = 1.0E-8d;
        private PrintMode print = PrintMode.ZEROS;
        private boolean exitOnFirstError = false;
        private boolean subset;
        private int maxPerParam;
        /** Parameter names to leave out of the check (may be {@code null}). */
        private Set<String> excludeParams;
        /** Optional callback handed the network; may be {@code null}. */
        private Consumer<ComputationGraph> callEachIter;

        // ---- Accessors -------------------------------------------------------------------

        public ComputationGraph net() {
            return this.net;
        }

        public INDArray[] inputs() {
            return this.inputs;
        }

        public INDArray[] labels() {
            return this.labels;
        }

        public INDArray[] inputMask() {
            return this.inputMask;
        }

        public INDArray[] labelMask() {
            return this.labelMask;
        }

        public double epsilon() {
            return this.epsilon;
        }

        public double maxRelError() {
            return this.maxRelError;
        }

        public double minAbsoluteError() {
            return this.minAbsoluteError;
        }

        public PrintMode print() {
            return this.print;
        }

        public boolean exitOnFirstError() {
            return this.exitOnFirstError;
        }

        public boolean subset() {
            return this.subset;
        }

        public int maxPerParam() {
            return this.maxPerParam;
        }

        public Set<String> excludeParams() {
            return this.excludeParams;
        }

        public Consumer<ComputationGraph> callEachIter() {
            return this.callEachIter;
        }

        // ---- Fluent setters (each stores the value and returns this) ----------------------

        public GraphConfig net(ComputationGraph computationGraph) {
            this.net = computationGraph;
            return this;
        }

        public GraphConfig inputs(INDArray[] iNDArrayArr) {
            this.inputs = iNDArrayArr;
            return this;
        }

        public GraphConfig labels(INDArray[] iNDArrayArr) {
            this.labels = iNDArrayArr;
            return this;
        }

        public GraphConfig inputMask(INDArray[] iNDArrayArr) {
            this.inputMask = iNDArrayArr;
            return this;
        }

        public GraphConfig labelMask(INDArray[] iNDArrayArr) {
            this.labelMask = iNDArrayArr;
            return this;
        }

        public GraphConfig epsilon(double d) {
            this.epsilon = d;
            return this;
        }

        public GraphConfig maxRelError(double d) {
            this.maxRelError = d;
            return this;
        }

        public GraphConfig minAbsoluteError(double d) {
            this.minAbsoluteError = d;
            return this;
        }

        public GraphConfig print(PrintMode printMode) {
            this.print = printMode;
            return this;
        }

        public GraphConfig exitOnFirstError(boolean z) {
            this.exitOnFirstError = z;
            return this;
        }

        public GraphConfig subset(boolean z) {
            this.subset = z;
            return this;
        }

        public GraphConfig maxPerParam(int i) {
            this.maxPerParam = i;
            return this;
        }

        public GraphConfig excludeParams(Set<String> set) {
            this.excludeParams = set;
            return this;
        }

        public GraphConfig callEachIter(Consumer<ComputationGraph> consumer) {
            this.callEachIter = consumer;
            return this;
        }

        // ---- Object contract --------------------------------------------------------------

        /**
         * Field-by-field equality. Array fields are compared with {@link Arrays#deepEquals},
         * doubles with {@link Double#compare}, everything else null-safely via {@code equals}.
         *
         * @param obj object to compare against
         * @return {@code true} if {@code obj} is a {@code GraphConfig} with equal settings
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof GraphConfig)) {
                return false;
            }
            GraphConfig other = (GraphConfig) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(epsilon(), other.epsilon()) != 0
                    || Double.compare(maxRelError(), other.maxRelError()) != 0
                    || Double.compare(minAbsoluteError(), other.minAbsoluteError()) != 0
                    || exitOnFirstError() != other.exitOnFirstError()
                    || subset() != other.subset()
                    || maxPerParam() != other.maxPerParam()) {
                return false;
            }
            if (!nullSafeEquals(net(), other.net())) {
                return false;
            }
            if (!Arrays.deepEquals(inputs(), other.inputs())
                    || !Arrays.deepEquals(labels(), other.labels())
                    || !Arrays.deepEquals(inputMask(), other.inputMask())
                    || !Arrays.deepEquals(labelMask(), other.labelMask())) {
                return false;
            }
            return nullSafeEquals(print(), other.print())
                    && nullSafeEquals(excludeParams(), other.excludeParams())
                    && nullSafeEquals(callEachIter(), other.callEachIter());
        }

        /** Hook that lets subclasses opt out of equality with instances of this class. */
        protected boolean canEqual(Object obj) {
            return obj instanceof GraphConfig;
        }

        /**
         * Hash code consistent with {@link #equals(Object)}; combines all fields with the
         * multiplier 59 (null references contribute 43, booleans 79/97).
         */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashDouble(epsilon());
            result = result * 59 + hashDouble(maxRelError());
            result = result * 59 + hashDouble(minAbsoluteError());
            result = result * 59 + (exitOnFirstError() ? 79 : 97);
            result = result * 59 + (subset() ? 79 : 97);
            result = result * 59 + maxPerParam();
            result = result * 59 + nullSafeHash(net());
            result = result * 59 + Arrays.deepHashCode(inputs());
            result = result * 59 + Arrays.deepHashCode(labels());
            result = result * 59 + Arrays.deepHashCode(inputMask());
            result = result * 59 + Arrays.deepHashCode(labelMask());
            result = result * 59 + nullSafeHash(print());
            result = result * 59 + nullSafeHash(excludeParams());
            result = result * 59 + nullSafeHash(callEachIter());
            return result;
        }

        /**
         * Renders every setting as {@code name=value}. Each label is paired with its own field;
         * previously several labels printed the value of a different field and the last three
         * fields were never printed.
         */
        @Override
        public String toString() {
            return "GradientCheckUtil.GraphConfig(net=" + net()
                    + ", inputs=" + Arrays.deepToString(inputs())
                    + ", labels=" + Arrays.deepToString(labels())
                    + ", inputMask=" + Arrays.deepToString(inputMask())
                    + ", labelMask=" + Arrays.deepToString(labelMask())
                    + ", epsilon=" + epsilon()
                    + ", maxRelError=" + maxRelError()
                    + ", minAbsoluteError=" + minAbsoluteError()
                    + ", print=" + print()
                    + ", exitOnFirstError=" + exitOnFirstError()
                    + ", subset=" + subset()
                    + ", maxPerParam=" + maxPerParam()
                    + ", excludeParams=" + excludeParams()
                    + ", callEachIter=" + callEachIter()
                    + ")";
        }

        /** Null-safe {@code equals}: two nulls are equal, null never equals non-null. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Hash of a possibly-null reference; null maps to 43. */
        private static int nullSafeHash(Object o) {
            return o == null ? 43 : o.hashCode();
        }

        /** Folds the 64-bit representation of a double into an int (high bits XOR low bits). */
        private static int hashDouble(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) ((bits >>> 32) ^ bits);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil$MLNConfig.class */
    /**
     * Fluent, mutable holder for the settings of a gradient check on a {@link MultiLayerNetwork},
     * consumed by {@code GradientCheckUtil.checkGradients(MLNConfig)}.
     *
     * <p>Every setter returns {@code this} so calls can be chained.
     *
     * <p>Defaults: {@code epsilon = 1e-6}, {@code maxRelError = 1e-3},
     * {@code minAbsoluteError = 1e-8}, {@code print = PrintMode.ZEROS},
     * {@code exitOnFirstError = false}; all other fields start as {@code null} / {@code false} / {@code 0}.
     */
    public static class MLNConfig {
        /** Network under test. */
        private MultiLayerNetwork net;
        /** Input array set on the network before scoring. */
        private INDArray input;
        /** Label array set on the network before scoring. */
        private INDArray labels;
        /** Optional input mask (may be {@code null}). */
        private INDArray inputMask;
        /** Optional label mask (may be {@code null}). */
        private INDArray labelMask;
        /** Perturbation added to / subtracted from each parameter; checkGradients requires (0, 0.1]. */
        private double epsilon = 1.0E-6d;
        /** Relative error above which a parameter is treated as a candidate failure; checkGradients requires (0, 0.25]. */
        private double maxRelError = 0.001d;
        /** A candidate failure only counts when the absolute error is at least this large. */
        private double minAbsoluteError = 1.0E-8d;
        /** Controls how much is logged. */
        private PrintMode print = PrintMode.ZEROS;
        /** When {@code true}, checkGradients returns {@code false} on the first counted failure. */
        private boolean exitOnFirstError = false;
        /** When {@code true}, per-parameter step sizes are derived from {@link #maxPerParam}. */
        private boolean subset;
        /** Divisor used for subset stepping; must be non-zero when {@link #subset} is {@code true}. */
        private int maxPerParam;
        /** Parameter names to skip (may be {@code null}). */
        private Set<String> excludeParams;
        /** Optional callback invoked with the network before each score computation; may be {@code null}. */
        private Consumer<MultiLayerNetwork> callEachIter;

        // ---- Accessors -------------------------------------------------------------------

        public MultiLayerNetwork net() {
            return this.net;
        }

        public INDArray input() {
            return this.input;
        }

        public INDArray labels() {
            return this.labels;
        }

        public INDArray inputMask() {
            return this.inputMask;
        }

        public INDArray labelMask() {
            return this.labelMask;
        }

        public double epsilon() {
            return this.epsilon;
        }

        public double maxRelError() {
            return this.maxRelError;
        }

        public double minAbsoluteError() {
            return this.minAbsoluteError;
        }

        public PrintMode print() {
            return this.print;
        }

        public boolean exitOnFirstError() {
            return this.exitOnFirstError;
        }

        public boolean subset() {
            return this.subset;
        }

        public int maxPerParam() {
            return this.maxPerParam;
        }

        public Set<String> excludeParams() {
            return this.excludeParams;
        }

        public Consumer<MultiLayerNetwork> callEachIter() {
            return this.callEachIter;
        }

        // ---- Fluent setters (each stores the value and returns this) ----------------------

        public MLNConfig net(MultiLayerNetwork multiLayerNetwork) {
            this.net = multiLayerNetwork;
            return this;
        }

        public MLNConfig input(INDArray iNDArray) {
            this.input = iNDArray;
            return this;
        }

        public MLNConfig labels(INDArray iNDArray) {
            this.labels = iNDArray;
            return this;
        }

        public MLNConfig inputMask(INDArray iNDArray) {
            this.inputMask = iNDArray;
            return this;
        }

        public MLNConfig labelMask(INDArray iNDArray) {
            this.labelMask = iNDArray;
            return this;
        }

        public MLNConfig epsilon(double d) {
            this.epsilon = d;
            return this;
        }

        public MLNConfig maxRelError(double d) {
            this.maxRelError = d;
            return this;
        }

        public MLNConfig minAbsoluteError(double d) {
            this.minAbsoluteError = d;
            return this;
        }

        public MLNConfig print(PrintMode printMode) {
            this.print = printMode;
            return this;
        }

        public MLNConfig exitOnFirstError(boolean z) {
            this.exitOnFirstError = z;
            return this;
        }

        public MLNConfig subset(boolean z) {
            this.subset = z;
            return this;
        }

        public MLNConfig maxPerParam(int i) {
            this.maxPerParam = i;
            return this;
        }

        public MLNConfig excludeParams(Set<String> set) {
            this.excludeParams = set;
            return this;
        }

        public MLNConfig callEachIter(Consumer<MultiLayerNetwork> consumer) {
            this.callEachIter = consumer;
            return this;
        }

        // ---- Object contract --------------------------------------------------------------

        /**
         * Field-by-field equality. Doubles are compared with {@link Double#compare}; reference
         * fields are compared null-safely via their own {@code equals}.
         *
         * @param obj object to compare against
         * @return {@code true} if {@code obj} is an {@code MLNConfig} with equal settings
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof MLNConfig)) {
                return false;
            }
            MLNConfig other = (MLNConfig) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (Double.compare(epsilon(), other.epsilon()) != 0
                    || Double.compare(maxRelError(), other.maxRelError()) != 0
                    || Double.compare(minAbsoluteError(), other.minAbsoluteError()) != 0
                    || exitOnFirstError() != other.exitOnFirstError()
                    || subset() != other.subset()
                    || maxPerParam() != other.maxPerParam()) {
                return false;
            }
            return nullSafeEquals(net(), other.net())
                    && nullSafeEquals(input(), other.input())
                    && nullSafeEquals(labels(), other.labels())
                    && nullSafeEquals(inputMask(), other.inputMask())
                    && nullSafeEquals(labelMask(), other.labelMask())
                    && nullSafeEquals(print(), other.print())
                    && nullSafeEquals(excludeParams(), other.excludeParams())
                    && nullSafeEquals(callEachIter(), other.callEachIter());
        }

        /** Hook that lets subclasses opt out of equality with instances of this class. */
        protected boolean canEqual(Object obj) {
            return obj instanceof MLNConfig;
        }

        /**
         * Hash code consistent with {@link #equals(Object)}; combines all fields with the
         * multiplier 59 (null references contribute 43, booleans 79/97).
         */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashDouble(epsilon());
            result = result * 59 + hashDouble(maxRelError());
            result = result * 59 + hashDouble(minAbsoluteError());
            result = result * 59 + (exitOnFirstError() ? 79 : 97);
            result = result * 59 + (subset() ? 79 : 97);
            result = result * 59 + maxPerParam();
            result = result * 59 + nullSafeHash(net());
            result = result * 59 + nullSafeHash(input());
            result = result * 59 + nullSafeHash(labels());
            result = result * 59 + nullSafeHash(inputMask());
            result = result * 59 + nullSafeHash(labelMask());
            result = result * 59 + nullSafeHash(print());
            result = result * 59 + nullSafeHash(excludeParams());
            result = result * 59 + nullSafeHash(callEachIter());
            return result;
        }

        /**
         * Renders every setting as {@code name=value}. Each label is paired with its own field;
         * previously several labels printed the value of a different field and the last three
         * fields were never printed.
         */
        @Override
        public String toString() {
            return "GradientCheckUtil.MLNConfig(net=" + net()
                    + ", input=" + input()
                    + ", labels=" + labels()
                    + ", inputMask=" + inputMask()
                    + ", labelMask=" + labelMask()
                    + ", epsilon=" + epsilon()
                    + ", maxRelError=" + maxRelError()
                    + ", minAbsoluteError=" + minAbsoluteError()
                    + ", print=" + print()
                    + ", exitOnFirstError=" + exitOnFirstError()
                    + ", subset=" + subset()
                    + ", maxPerParam=" + maxPerParam()
                    + ", excludeParams=" + excludeParams()
                    + ", callEachIter=" + callEachIter()
                    + ")";
        }

        /** Null-safe {@code equals}: two nulls are equal, null never equals non-null. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Hash of a possibly-null reference; null maps to 43. */
        private static int nullSafeHash(Object o) {
            return o == null ? 43 : o.hashCode();
        }

        /** Folds the 64-bit representation of a double into an int (high bits XOR low bits). */
        private static int hashDouble(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) ((bits >>> 32) ^ bits);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil$PrintMode.class */
    /**
     * Verbosity setting for gradient-check logging. The constant order (and therefore each
     * ordinal) is part of the existing interface and is kept as-is.
     */
    public enum PrintMode {

        /** Most verbose setting; NOTE(review): presumably logs every checked parameter - confirm. */
        ALL,

        /** Default of the config classes; NOTE(review): presumably also logs zero-error cases - confirm. */
        ZEROS,

        /** Setting used by the deprecated positional {@code checkGradients} overloads. */
        FAILURES_ONLY
    }

    /** Private no-op constructor: prevents instantiation from outside this class. */
    private GradientCheckUtil() {
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v43, types: [org.deeplearning4j.nn.conf.layers.BaseLayer] */
    /**
     * Resets the clipping epsilon of the output layer's loss function to
     * {@code EvaluationBinary.DEFAULT_EDGE_VALUE} (presumably 0.0, matching the log message)
     * so that clipping does not cause spurious gradient check failures.
     *
     * <p>Only two cases are touched: {@link LossMCXENT} combined with {@link ActivationSoftmax},
     * and {@link LossBinaryXENT}. Layers that are neither a {@link BaseOutputLayer} nor a
     * {@link LossLayer} are left alone.
     *
     * @param iOutputLayer output layer whose loss function configuration may be modified in place
     */
    private static void configureLossFnClippingIfPresent(IOutputLayer iOutputLayer) {
        final ILossFunction lossFn;
        final IActivation activationFn;
        if (iOutputLayer instanceof BaseOutputLayer) {
            BaseOutputLayer outputLayer = (BaseOutputLayer) iOutputLayer;
            lossFn = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) outputLayer.layerConf()).getLossFn();
            activationFn = outputLayer.layerConf().getActivationFn();
        } else if (iOutputLayer instanceof LossLayer) {
            LossLayer lossOnlyLayer = (LossLayer) iOutputLayer;
            lossFn = lossOnlyLayer.layerConf().getLossFn();
            activationFn = lossOnlyLayer.layerConf().getActivationFn();
        } else {
            // Unknown output layer type: no loss function to inspect.
            return;
        }

        // Softmax + MCXENT: disable softmax clipping if it is currently enabled.
        if (lossFn instanceof LossMCXENT && activationFn instanceof ActivationSoftmax) {
            LossMCXENT mcxent = (LossMCXENT) lossFn;
            if (mcxent.getSoftmaxClipEps() != EvaluationBinary.DEFAULT_EDGE_VALUE) {
                log.info("Setting softmax clipping epsilon to 0.0 for " + lossFn.getClass() + " loss function to avoid spurious gradient check failures");
                mcxent.setSoftmaxClipEps(EvaluationBinary.DEFAULT_EDGE_VALUE);
                return;
            }
        }

        // Binary cross entropy: disable clipping if it is currently enabled.
        if (lossFn instanceof LossBinaryXENT) {
            LossBinaryXENT binaryXent = (LossBinaryXENT) lossFn;
            if (binaryXent.getClipEps() != EvaluationBinary.DEFAULT_EDGE_VALUE) {
                log.info("Setting clipping epsilon to 0.0 for " + lossFn.getClass() + " loss function to avoid spurious gradient check failures");
                binaryXent.setClipEps(EvaluationBinary.DEFAULT_EDGE_VALUE);
            }
        }
    }

    /**
     * Legacy positional-argument entry point; builds an {@link MLNConfig} and delegates to
     * {@link #checkGradients(MLNConfig)}.
     *
     * @param multiLayerNetwork network to check
     * @param d                 epsilon
     * @param d2                maximum relative error
     * @param d3                minimum absolute error
     * @param z                 ignored: the print mode is always {@link PrintMode#FAILURES_ONLY}
     * @param z2                whether to stop at the first failure
     * @param iNDArray          network input
     * @param iNDArray2         labels
     * @return the result of {@link #checkGradients(MLNConfig)}
     * @deprecated use {@link #checkGradients(MLNConfig)} directly
     */
    @Deprecated
    public static boolean checkGradients(MultiLayerNetwork multiLayerNetwork, double d, double d2, double d3, boolean z, boolean z2, INDArray iNDArray, INDArray iNDArray2) {
        MLNConfig config = new MLNConfig();
        config.net(multiLayerNetwork);
        config.epsilon(d);
        config.maxRelError(d2);
        config.minAbsoluteError(d3);
        config.print(PrintMode.FAILURES_ONLY);
        config.exitOnFirstError(z2);
        config.input(iNDArray);
        config.labels(iNDArray2);
        return checkGradients(config);
    }

    /**
     * Legacy positional-argument entry point with masks, subset and seed support; builds an
     * {@link MLNConfig} and delegates to {@link #checkGradients(MLNConfig)}.
     *
     * @param multiLayerNetwork network to check
     * @param d                 epsilon
     * @param d2                maximum relative error
     * @param d3                minimum absolute error
     * @param z                 ignored: the print mode is always {@link PrintMode#FAILURES_ONLY}
     * @param z2                whether to stop at the first failure
     * @param iNDArray          network input
     * @param iNDArray2         labels
     * @param iNDArray3         input mask
     * @param iNDArray4         label mask
     * @param z3                value for the {@code subset} setting
     * @param i                 value for the {@code maxPerParam} setting
     * @param set               parameter names to exclude
     * @param num               if non-null, the RNG seed is reset to this value by the
     *                          per-iteration callback; if null, no callback is installed
     * @return the result of {@link #checkGradients(MLNConfig)}
     * @deprecated use {@link #checkGradients(MLNConfig)} directly
     */
    @Deprecated
    public static boolean checkGradients(MultiLayerNetwork multiLayerNetwork, double d, double d2, double d3, boolean z, boolean z2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z3, int i, Set<String> set, final Integer num) {
        // Only install a seed-resetting callback when a seed was actually supplied.
        Consumer<MultiLayerNetwork> seedResetter = num == null ? null : new Consumer<MultiLayerNetwork>() {
            @Override
            public void accept(MultiLayerNetwork ignoredNet) {
                Nd4j.getRandom().setSeed(num.intValue());
            }
        };

        MLNConfig config = new MLNConfig();
        config.net(multiLayerNetwork);
        config.epsilon(d);
        config.maxRelError(d2);
        config.minAbsoluteError(d3);
        config.print(PrintMode.FAILURES_ONLY);
        config.exitOnFirstError(z2);
        config.input(iNDArray);
        config.labels(iNDArray2);
        config.inputMask(iNDArray3);
        config.labelMask(iNDArray4);
        config.subset(z3);
        config.maxPerParam(i);
        config.excludeParams(set);
        config.callEachIter(seedResetter);
        return checkGradients(config);
    }

    public static boolean checkGradients(MLNConfig mLNConfig) {
        HashMap hashMap;
        long j;
        if (mLNConfig.epsilon <= EvaluationBinary.DEFAULT_EDGE_VALUE || mLNConfig.epsilon > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (mLNConfig.maxRelError <= EvaluationBinary.DEFAULT_EDGE_VALUE || mLNConfig.maxRelError > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + mLNConfig.maxRelError);
        }
        if (!(mLNConfig.net.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        DataType dtypeFromContext = DataTypeUtil.getDtypeFromContext();
        if (dtypeFromContext != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dtypeFromContext + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }
        DataType dataType = mLNConfig.net.getLayerWiseConfigurations().getDataType();
        if (dataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil");
        }
        if (dataType != mLNConfig.net.params().dataType()) {
            throw new IllegalStateException("Parameters datatype does not match network configuration datatype (is: " + mLNConfig.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE.");
        }
        for (NeuralNetConfiguration neuralNetConfiguration : mLNConfig.net.getLayerWiseConfigurations().getConfs()) {
            if (neuralNetConfiguration.getLayer() instanceof BaseLayer) {
                BaseLayer baseLayer = (BaseLayer) neuralNetConfiguration.getLayer();
                Sgd iUpdater = baseLayer.getIUpdater();
                if (iUpdater instanceof Sgd) {
                    double learningRate = iUpdater.getLearningRate();
                    if (learningRate != 1.0d) {
                        neuralNetConfiguration.getLayer().getLayerName();
                        IllegalStateException illegalStateException = new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + 0 + "; got " + iUpdater + " with lr=" + learningRate + " for layer \"" + illegalStateException + "\"");
                        throw illegalStateException;
                    }
                } else if (!(iUpdater instanceof NoOp)) {
                    throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + 0 + "; got " + iUpdater);
                }
                IActivation activationFn = baseLayer.getActivationFn();
                if (activationFn != null && !VALID_ACTIVATION_FUNCTIONS.contains(activationFn.getClass())) {
                    log.warn("Layer " + 0 + " is possibly using an unsuitable activation function: " + activationFn.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
                }
            }
            if (neuralNetConfiguration.getLayer().getIDropout() != null && mLNConfig.callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no dropout should be present during gradient checks - got dropout = " + neuralNetConfiguration.getLayer().getIDropout() + " for layer " + 0);
            }
        }
        for (Layer layer : mLNConfig.net.getLayers()) {
            if (layer instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) layer);
            }
        }
        mLNConfig.net.setInput(mLNConfig.input);
        mLNConfig.net.setLabels(mLNConfig.labels);
        mLNConfig.net.setLayerMaskArrays(mLNConfig.inputMask, mLNConfig.labelMask);
        if (mLNConfig.callEachIter != null) {
            mLNConfig.callEachIter.accept(mLNConfig.net);
        }
        mLNConfig.net.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = mLNConfig.net.gradientAndScore();
        UpdaterCreator.getUpdater(mLNConfig.net).update(mLNConfig.net, (Gradient) gradientAndScore.getFirst(), 0, 0, mLNConfig.net.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray dup = ((Gradient) gradientAndScore.getFirst()).gradient().dup();
        long length = mLNConfig.net.params().dup().length();
        Map<String, INDArray> paramTable = mLNConfig.net.paramTable();
        ArrayList arrayList = new ArrayList(paramTable.keySet());
        long[] jArr = new long[arrayList.size()];
        jArr[0] = paramTable.get(arrayList.get(0)).length();
        if (mLNConfig.subset) {
            hashMap = new HashMap();
            hashMap.put((String) arrayList.get(0), Integer.valueOf((int) Math.max(1L, paramTable.get(arrayList.get(0)).length() / mLNConfig.maxPerParam)));
        } else {
            hashMap = null;
        }
        for (int i = 1; i < jArr.length; i++) {
            long length2 = paramTable.get(arrayList.get(i)).length();
            jArr[i] = jArr[i - 1] + length2;
            if (mLNConfig.subset) {
                long j2 = length2 / mLNConfig.maxPerParam;
                if (j2 == 0) {
                    j2 = length2;
                }
                if (j2 > 2147483647L) {
                    throw new ND4JArraySizeException();
                }
                hashMap.put((String) arrayList.get(i), Integer.valueOf((int) j2));
            }
        }
        if (mLNConfig.print == PrintMode.ALL) {
            int i2 = 0;
            for (Layer layer2 : mLNConfig.net.getLayers()) {
                log.info("Layer " + i2 + ": " + layer2.getClass().getSimpleName() + " - params " + layer2.paramTable().keySet());
                i2++;
            }
        }
        int i3 = 0;
        double d = 0.0d;
        DataSet dataSet = new DataSet(mLNConfig.input, mLNConfig.labels, mLNConfig.inputMask, mLNConfig.labelMask);
        int i4 = 0;
        if (mLNConfig.excludeParams != null && !mLNConfig.excludeParams.isEmpty()) {
            log.info("NOTE: parameters will be skipped due to config: {}", mLNConfig.excludeParams);
        }
        INDArray params = mLNConfig.net.params();
        long j3 = 0;
        while (true) {
            long j4 = j3;
            if (j4 >= length) {
                long j5 = length - i3;
                Logger logger = log;
                logger.info("GradientCheckUtil.checkGradients(): " + length + " params checked, " + logger + " passed, " + j5 + " failed. Largest relative error = " + logger);
                return i3 == 0;
            }
            if (j4 >= jArr[i4]) {
                i4++;
            }
            String str = (String) arrayList.get(i4);
            if (mLNConfig.excludeParams == null || !mLNConfig.excludeParams.contains(str)) {
                double d2 = params.getDouble(j4);
                params.putScalar(j4, d2 + mLNConfig.epsilon);
                if (mLNConfig.callEachIter != null) {
                    mLNConfig.callEachIter.accept(mLNConfig.net);
                }
                double score = mLNConfig.net.score(dataSet, true);
                params.putScalar(j4, d2 - mLNConfig.epsilon);
                if (mLNConfig.callEachIter != null) {
                    mLNConfig.callEachIter.accept(mLNConfig.net);
                }
                double score2 = mLNConfig.net.score(dataSet, true);
                params.putScalar(j4, d2);
                double d3 = (score - score2) / (2.0d * mLNConfig.epsilon);
                if (Double.isNaN(d3)) {
                    IllegalStateException illegalStateException2 = new IllegalStateException("Numerical gradient was NaN for parameter " + j4 + " of " + illegalStateException2);
                    throw illegalStateException2;
                }
                double d4 = dup.getDouble(j4);
                double abs = Math.abs(d4 - d3) / (Math.abs(d3) + Math.abs(d4));
                if (d4 == EvaluationBinary.DEFAULT_EDGE_VALUE && d3 == EvaluationBinary.DEFAULT_EDGE_VALUE) {
                    abs = 0.0d;
                }
                if (abs > d) {
                    d = abs;
                }
                if (abs > mLNConfig.maxRelError || Double.isNaN(abs)) {
                    double abs2 = Math.abs(d4 - d3);
                    if (abs2 >= mLNConfig.minAbsoluteError) {
                        Logger logger2 = log;
                        logger2.info("Param " + j4 + " (" + logger2 + ") FAILED: grad= " + str + ", numericalGrad= " + d4 + ", relError= " + logger2 + ", scorePlus=" + d3 + ", scoreMinus= " + logger2 + ", paramValue = " + abs);
                        if (mLNConfig.exitOnFirstError) {
                            return false;
                        }
                        i3++;
                    } else if (mLNConfig.print == PrintMode.ALL || (mLNConfig.print == PrintMode.ZEROS && abs2 == EvaluationBinary.DEFAULT_EDGE_VALUE)) {
                        Logger logger3 = log;
                        double d5 = mLNConfig.minAbsoluteError;
                        logger3.info("Param " + j4 + " (" + logger3 + ") passed: grad= " + str + ", numericalGrad= " + d4 + ", relError= " + logger3 + "; absolute error = " + d3 + " < minAbsoluteError = " + logger3);
                    }
                } else if (mLNConfig.print == PrintMode.ALL) {
                    Logger logger4 = log;
                    logger4.info("Param " + j4 + " (" + logger4 + ") passed: grad= " + str + ", numericalGrad= " + d4 + ", relError= " + logger4);
                }
                if (mLNConfig.subset) {
                    j = ((Integer) hashMap.get(str)).intValue();
                    if (j4 + j > jArr[i4] + 1) {
                        j = (jArr[i4] + 1) - j4;
                    }
                } else {
                    j = 1;
                }
                j3 = j4 + j;
            } else {
                int i5 = i4;
                i4++;
                j3 = jArr[i5];
            }
        }
    }

    /**
     * Gradient-checks a {@link ComputationGraph}: compares the gradient produced by
     * {@code computeGradientAndScore()} with a central-difference numerical gradient
     * {@code (score(w + eps) - score(w - eps)) / (2 * eps)}, one parameter at a time.
     *
     * <p>A parameter fails when its relative error
     * {@code |g - n| / (|g| + |n|)} exceeds {@code maxRelError} (or is NaN) and its absolute error
     * is not below {@code minAbsoluteError}. Parameters whose name is listed in
     * {@code excludeParams} are skipped entirely.
     *
     * @param graphConfig network, inputs, labels, masks and tolerances for the check
     * @return {@code true} if no parameter failed; {@code false} otherwise (immediately on the first
     *         failure when {@code exitOnFirstError} is set)
     * @throws IllegalArgumentException if epsilon is outside (0, 0.1], maxRelError is outside (0, 0.25],
     *         or the number of input/label arrays does not match the network
     * @throws IllegalStateException if the context, network or parameter data type is not DOUBLE, a layer
     *         uses an updater other than NoOp or SGD with lr=1.0, dropout is configured without a
     *         {@code callEachIter} callback, or a numerical gradient evaluates to NaN
     */
    public static boolean checkGradients(GraphConfig graphConfig) {
        // ---- Sanity checks on the supplied configuration ----
        if (graphConfig.epsilon <= 0.0 || graphConfig.epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (graphConfig.maxRelError <= 0.0 || graphConfig.maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + graphConfig.maxRelError);
        }
        if (graphConfig.net.getNumInputArrays() != graphConfig.inputs.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + graphConfig.net.getNumInputArrays() + " inputs");
        }
        if (graphConfig.net.getNumOutputArrays() != graphConfig.labels.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + graphConfig.net.getNumOutputArrays() + " outputs");
        }

        DataType contextDataType = DataTypeUtil.getDtypeFromContext();
        if (contextDataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + contextDataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }
        DataType netDataType = graphConfig.net.getConfiguration().getDataType();
        if (netDataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil");
        }
        if (netDataType != graphConfig.net.params().dataType()) {
            throw new IllegalStateException("Parameters datatype does not match network configuration datatype (is: " + graphConfig.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE.");
        }

        // ---- Per-vertex configuration checks (only layer vertices are inspected) ----
        // NOTE(review): the numeric layer index in the messages below is always reported as 0 by this
        // code; the vertex / layer name is the reliable identifier.
        for (Map.Entry<String, GraphVertex> vertexEntry : graphConfig.net.getConfiguration().getVertices().entrySet()) {
            if (!(vertexEntry.getValue() instanceof LayerVertex)) {
                continue;
            }
            String vertexName = vertexEntry.getKey();
            LayerVertex layerVertex = (LayerVertex) vertexEntry.getValue();

            if (layerVertex.getLayerConf().getLayer() instanceof BaseLayer) {
                BaseLayer baseLayer = (BaseLayer) layerVertex.getLayerConf().getLayer();
                // Held as Object: only instanceof tests and string conversion are needed here.
                Object updater = baseLayer.getIUpdater();
                if (updater instanceof Sgd) {
                    double learningRate = ((Sgd) updater).getLearningRate();
                    if (learningRate != 1.0) {
                        throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer 0; got " + updater + " with lr=" + learningRate + " for layer \"" + layerVertex.getLayerConf().getLayer().getLayerName() + "\"");
                    }
                } else if (!(updater instanceof NoOp)) {
                    throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer 0; got " + updater);
                }

                IActivation activationFn = baseLayer.getActivationFn();
                if (activationFn != null && !VALID_ACTIVATION_FUNCTIONS.contains(activationFn.getClass())) {
                    log.warn("Layer \"" + vertexName + "\" is possibly using an unsuitable activation function: " + activationFn.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
                }
            }

            if (layerVertex.getLayerConf().getLayer().getIDropout() != null && graphConfig.callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or no dropout should be present during gradient checks - got dropout = " + layerVertex.getLayerConf().getLayer().getIDropout() + " for layer 0");
            }
        }

        for (Layer layer : graphConfig.net.getLayers()) {
            if (layer instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) layer);
            }
        }

        // ---- Analytic gradient ----
        for (int inputIdx = 0; inputIdx < graphConfig.inputs.length; inputIdx++) {
            graphConfig.net.setInput(inputIdx, graphConfig.inputs[inputIdx]);
        }
        for (int labelIdx = 0; labelIdx < graphConfig.labels.length; labelIdx++) {
            graphConfig.net.setLabel(labelIdx, graphConfig.labels[labelIdx]);
        }
        graphConfig.net.setLayerMaskArrays(graphConfig.inputMask, graphConfig.labelMask);
        if (graphConfig.callEachIter != null) {
            graphConfig.callEachIter.accept(graphConfig.net);
        }
        graphConfig.net.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = graphConfig.net.gradientAndScore();
        new ComputationGraphUpdater(graphConfig.net).update(gradientAndScore.getFirst(), 0, 0, graphConfig.net.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        // Copied up front; presumably the returned array is a view that later score() calls may
        // overwrite - TODO confirm.
        INDArray gradientToCheck = gradientAndScore.getFirst().gradient().dup();
        long nParams = graphConfig.net.params().length();

        // paramEnds[k] = exclusive end offset, in the flattened parameter vector, of the k-th named parameter.
        Map<String, INDArray> paramTable = graphConfig.net.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = new long[paramNames.size()];
        paramEnds[0] = paramTable.get(paramNames.get(0)).length();
        for (int k = 1; k < paramEnds.length; k++) {
            paramEnds[k] = paramEnds[k - 1] + paramTable.get(paramNames.get(k)).length();
        }

        if (graphConfig.excludeParams != null && !graphConfig.excludeParams.isEmpty()) {
            log.info("NOTE: parameters will be skipped due to config: {}", graphConfig.excludeParams);
        }

        // ---- Numerical gradient, parameter by parameter ----
        int paramNameIdx = 0;
        int totalFailures = 0;
        double maxError = 0.0;
        MultiDataSet multiDataSet = new org.nd4j.linalg.dataset.MultiDataSet(graphConfig.inputs, graphConfig.labels, graphConfig.inputMask, graphConfig.labelMask);
        // Perturbed in place below; this assumes params() exposes the network's live parameters.
        INDArray params = graphConfig.net.params();
        long idx = 0;
        while (idx < nParams) {
            if (idx >= paramEnds[paramNameIdx]) {
                paramNameIdx++;
            }
            String paramName = paramNames.get(paramNameIdx);
            if (graphConfig.excludeParams != null && graphConfig.excludeParams.contains(paramName)) {
                // Jump exactly to the first element of the next parameter (no extra +1, which would
                // silently leave that element unchecked).
                idx = paramEnds[paramNameIdx++];
                continue;
            }

            double origValue = params.getDouble(idx);

            params.putScalar(idx, origValue + graphConfig.epsilon);
            if (graphConfig.callEachIter != null) {
                graphConfig.callEachIter.accept(graphConfig.net);
            }
            double scorePlus = graphConfig.net.score(multiDataSet, true);

            params.putScalar(idx, origValue - graphConfig.epsilon);
            if (graphConfig.callEachIter != null) {
                graphConfig.callEachIter.accept(graphConfig.net);
            }
            double scoreMinus = graphConfig.net.score(multiDataSet, true);

            // Restore the parameter before moving on.
            params.putScalar(idx, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * graphConfig.epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + idx + " of " + nParams);
            }

            double backpropGradient = gradientToCheck.getDouble(idx);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                // 0/0 above would be NaN; two exact zeros agree.
                relError = 0.0;
            }
            if (relError > maxError) {
                maxError = relError;
            }

            if (relError > graphConfig.maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                // Written as "< min => pass, otherwise fail" so that a NaN absolute error counts as a failure.
                if (absError < graphConfig.minAbsoluteError) {
                    if (graphConfig.print == PrintMode.ALL || (graphConfig.print == PrintMode.ZEROS && absError == 0.0)) {
                        log.info("Param {} ({}) passed: grad= {}, numericalGrad= {}, relError= {}; absolute error = {} < minAbsoluteError = {}",
                                idx, paramName, backpropGradient, numericalGradient, relError, absError, graphConfig.minAbsoluteError);
                    }
                } else {
                    log.info("Param {} ({}) FAILED: grad= {}, numericalGrad= {}, relError= {}, scorePlus={}, scoreMinus= {}, paramValue = {}",
                            idx, paramName, backpropGradient, numericalGradient, relError, scorePlus, scoreMinus, origValue);
                    if (graphConfig.exitOnFirstError) {
                        return false;
                    }
                    totalFailures++;
                }
            } else if (graphConfig.print == PrintMode.ALL) {
                log.info("Param {} ({}) passed: grad= {}, numericalGrad= {}, relError= {}",
                        idx, paramName, backpropGradient, numericalGradient, relError);
            }
            idx++;
        }

        long nPassed = nParams - totalFailures;
        log.info("GradientCheckUtil.checkGradients(): {} params checked, {} passed, {} failed. Largest relative error = {}",
                nParams, nPassed, totalFailures, maxError);
        return totalFailures == 0;
    }

    /**
     * Gradient-checks a single layer using the layer's own score: compares the gradient from
     * {@code layer.computeGradientAndScore(...)} with a central-difference numerical gradient
     * {@code (score(w + eps) - score(w - eps)) / (2 * eps)} for every parameter.
     *
     * <p>The ND4J random seed is set to {@code i} before every gradient/score computation
     * (presumably so that any stochastic behaviour in the layer is repeated identically - TODO confirm).
     *
     * @param layer    layer to check
     * @param d        epsilon: perturbation size, must be in (0, 0.1]
     * @param d2       maximum allowed relative error, must be in (0, 0.25]
     * @param d3       minimum absolute error: a parameter exceeding {@code d2} still passes if its
     *                 absolute error is below this value
     * @param z        if {@code true}, log the per-parameter pass/fail results and the final summary
     * @param z2       if {@code true}, return {@code false} on the first failing parameter
     * @param iNDArray input set on the layer before the check
     * @param i        seed passed to {@code Nd4j.getRandom().setSeed(...)}
     * @return {@code true} if no parameter failed
     * @throws IllegalArgumentException if {@code d} or {@code d2} is out of range
     * @throws IllegalStateException if the context data type is not DOUBLE or a numerical gradient is NaN
     */
    public static boolean checkGradientsPretrainLayer(Layer layer, double d, double d2, double d3, boolean z, boolean z2, INDArray iNDArray, int i) {
        // Readable aliases for the positional parameters.
        final double epsilon = d;
        final double maxRelError = d2;
        final double minAbsoluteError = d3;
        final boolean print = z;
        final boolean exitOnFirstError = z2;
        final int rngSeed = i;

        LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces();

        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        DataType contextDataType = DataTypeUtil.getDtypeFromContext();
        if (contextDataType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + contextDataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }

        // ---- Analytic gradient ----
        layer.setInput(iNDArray, LayerWorkspaceMgr.noWorkspaces());
        Nd4j.getRandom().setSeed(rngSeed);
        layer.computeGradientAndScore(workspaceMgr);
        Pair<Gradient, Double> gradientAndScore = layer.gradientAndScore();
        UpdaterCreator.getUpdater(layer).update(layer, gradientAndScore.getFirst(), 0, 0, layer.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        // Copied up front; presumably the returned array is a view that later computations may
        // overwrite - TODO confirm.
        INDArray gradientToCheck = gradientAndScore.getFirst().gradient().dup();
        long nParams = layer.params().length();

        // paramEnds[k] = exclusive end offset, in the flattened parameter vector, of the k-th named parameter.
        Map<String, INDArray> paramTable = layer.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = new long[paramNames.size()];
        paramEnds[0] = paramTable.get(paramNames.get(0)).length();
        for (int k = 1; k < paramEnds.length; k++) {
            paramEnds[k] = paramEnds[k - 1] + paramTable.get(paramNames.get(k)).length();
        }

        // ---- Numerical gradient, parameter by parameter ----
        int totalFailures = 0;
        double maxError = 0.0;
        int paramNameIdx = 0;
        // Perturbed in place below; this assumes params() exposes the layer's live parameters.
        INDArray params = layer.params();
        // long index: the parameter count is a long, an int counter could overflow.
        for (long idx = 0; idx < nParams; idx++) {
            if (idx >= paramEnds[paramNameIdx]) {
                paramNameIdx++;
            }
            String paramName = paramNames.get(paramNameIdx);

            double origValue = params.getDouble(idx);

            params.putScalar(idx, origValue + epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore(workspaceMgr);
            double scorePlus = layer.score();

            params.putScalar(idx, origValue - epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore(workspaceMgr);
            double scoreMinus = layer.score();

            // Restore the parameter before moving on.
            params.putScalar(idx, origValue);

            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + idx + " of " + nParams);
            }

            double backpropGradient = gradientToCheck.getDouble(idx);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                // 0/0 above would be NaN; two exact zeros agree.
                relError = 0.0;
            }
            if (relError > maxError) {
                maxError = relError;
            }

            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < minAbsoluteError) {
                    // Logged regardless of the print flag, as in the original control flow.
                    log.info("Param {} ({}) passed: grad= {}, numericalGrad= {}, relError= {}; absolute error = {} < minAbsoluteError = {}",
                            idx, paramName, backpropGradient, numericalGradient, relError, absError, minAbsoluteError);
                } else {
                    if (print) {
                        log.info("Param {} ({}) FAILED: grad= {}, numericalGrad= {}, relError= {}, scorePlus={}, scoreMinus= {}, paramValue = {}",
                                idx, paramName, backpropGradient, numericalGradient, relError, scorePlus, scoreMinus, origValue);
                    }
                    if (exitOnFirstError) {
                        return false;
                    }
                    totalFailures++;
                }
            } else if (print) {
                log.info("Param {} ({}) passed: grad= {}, numericalGrad= {}, relError= {}",
                        idx, paramName, backpropGradient, numericalGradient, relError);
            }
        }

        if (print) {
            long nPassed = nParams - totalFailures;
            log.info("GradientCheckUtil.checkGradients(): {} params checked, {} passed, {} failed. Largest relative error = {}",
                    nParams, nPassed, totalFailures, maxError);
        }
        return totalFailures == 0;
    }
}
