package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link BatchNormalization} layers.
 *
 * <p>The flattened parameter (and gradient) vector is laid out as consecutive segments, each of
 * length {@code nOut}:
 * <pre>
 *   [gamma | beta | mean | var-or-log10stdev]   when gamma/beta are trainable
 *   [mean | var-or-log10stdev]                  when lockGammaBeta is set
 * </pre>
 * The last segment is stored under {@link #GLOBAL_LOG_STD} when the layer uses a log standard
 * deviation, and under {@link #GLOBAL_VAR} otherwise.
 */
public class BatchNormalizationParamInitializer implements ParamInitializer {
    private static final BatchNormalizationParamInitializer INSTANCE = new BatchNormalizationParamInitializer();

    /** Key of the learned scale parameter (absent when gamma/beta are locked). */
    public static final String GAMMA = "gamma";
    /** Key of the learned shift parameter (absent when gamma/beta are locked). */
    public static final String BETA = "beta";
    /** Key of the global mean segment. */
    public static final String GLOBAL_MEAN = "mean";
    /** Key of the global variance segment (used when log-std mode is off). */
    public static final String GLOBAL_VAR = "var";
    /** Key of the global log10 standard-deviation segment (used when log-std mode is on). */
    public static final String GLOBAL_LOG_STD = "log10stdev";

    /** Returns the shared instance; this class keeps no per-instance state. */
    public static BatchNormalizationParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Returns the total parameter count: {@code 4 * nOut} normally, or {@code 2 * nOut} when
     * gamma/beta are locked (only the mean and variance segments are stored).
     */
    @Override
    public long numParams(Layer layer) {
        BatchNormalization bn = (BatchNormalization) layer;
        long segments = bn.isLockGammaBeta() ? 2 : 4;
        return segments * bn.getNOut();
    }

    /**
     * Returns the parameter keys in flattened-vector order. Gamma and beta are only included when
     * they are actually stored, i.e. when the layer does not lock them, so the result agrees with
     * {@link #numParams(Layer)} and with the map produced by {@code init}.
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        BatchNormalization bn = (BatchNormalization) layer;
        String varianceKey = varianceKey(bn);
        if (bn.isLockGammaBeta()) {
            return Arrays.asList(GLOBAL_MEAN, varianceKey);
        }
        return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, varianceKey);
    }

    /** Batch normalization has no parameters classified as weights. */
    @Override
    public List<String> weightKeys(Layer layer) {
        return Collections.emptyList();
    }

    /** Batch normalization has no parameters classified as biases. */
    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return false;
    }

    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Splits {@code paramsView} into the per-parameter segments and registers each key as a
     * variable on {@code conf}.
     *
     * @param conf             configuration whose layer must be a {@link BatchNormalization}
     * @param paramsView       backing parameter array; must contain exactly
     *                         {@code numParams(conf)} elements
     * @param initializeParams when true, gamma/beta are filled from the layer config, the mean
     *                         with 0, and the variance with 1 (or the log-std with 0)
     * @return insertion-ordered, synchronized map from parameter key to its segment
     * @throws IllegalStateException if {@code paramsView} has the wrong length
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        BatchNormalization bn = (BatchNormalization) conf.getLayer();
        long expectedLength = numParams(conf);
        if (paramsView.length() != expectedLength) {
            throw new IllegalStateException("Expected params view of length " + expectedLength
                    + ", got length " + paramsView.length());
        }
        long nOut = bn.getNOut();
        INDArray flat = flatten(paramsView);
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<>());

        long offset = 0;
        if (!bn.isLockGammaBeta()) {
            INDArray gammaView = segment(flat, 0, nOut);
            INDArray betaView = segment(flat, nOut, nOut);
            params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
            conf.addVariable(GAMMA);
            params.put(BETA, createBeta(conf, betaView, initializeParams));
            conf.addVariable(BETA);
            offset = 2 * nOut;
        }

        INDArray meanView = segment(flat, offset, nOut);
        INDArray varianceView = segment(flat, offset + nOut, nOut);
        if (initializeParams) {
            meanView.assign(0);
            // Unit variance either way: var = 1, or log10(stdev) = 0.
            varianceView.assign(bn.isUseLogStd() ? 0 : 1);
        }

        String varianceKey = varianceKey(bn);
        params.put(GLOBAL_MEAN, meanView);
        conf.addVariable(GLOBAL_MEAN);
        params.put(varianceKey, varianceView);
        conf.addVariable(varianceKey);
        return params;
    }

    /**
     * Splits a flattened gradient vector into per-parameter segments using the same layout as
     * {@code init}.
     *
     * @param conf         configuration whose layer must be a {@link BatchNormalization}
     * @param gradientView flattened gradient array
     * @return insertion-ordered map from parameter key to its gradient segment
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        BatchNormalization bn = (BatchNormalization) conf.getLayer();
        long nOut = bn.getNOut();
        INDArray flat = flatten(gradientView);
        Map<String, INDArray> gradients = new LinkedHashMap<>();

        long offset = 0;
        if (!bn.isLockGammaBeta()) {
            gradients.put(GAMMA, segment(flat, 0, nOut));
            gradients.put(BETA, segment(flat, nOut, nOut));
            offset = 2 * nOut;
        }
        gradients.put(GLOBAL_MEAN, segment(flat, offset, nOut));
        gradients.put(varianceKey(bn), segment(flat, offset + nOut, nOut));
        return gradients;
    }

    /** Returns the key used for the last segment: log10 stdev or plain variance. */
    private static String varianceKey(BatchNormalization bn) {
        return bn.isUseLogStd() ? GLOBAL_LOG_STD : GLOBAL_VAR;
    }

    /** Reshapes {@code array} into a rank-1 vector with the same number of elements. */
    private static INDArray flatten(INDArray array) {
        return array.reshape(new long[] {array.length()});
    }

    /**
     * Returns the sub-vector {@code [start, start + length)} of {@code flat}. ND4J interval
     * indexing is expected to return a view, so writes go through to the backing array.
     */
    private static INDArray segment(INDArray flat, long start, long length) {
        return flat.get(NDArrayIndex.interval(start, start + length));
    }

    /** Fills {@code betaView} with the layer's configured beta when initializing; returns it. */
    private INDArray createBeta(NeuralNetConfiguration conf, INDArray betaView, boolean initializeParams) {
        if (initializeParams) {
            BatchNormalization bn = (BatchNormalization) conf.getLayer();
            betaView.assign(bn.getBeta());
        }
        return betaView;
    }

    /** Fills {@code gammaView} with the layer's configured gamma when initializing; returns it. */
    private INDArray createGamma(NeuralNetConfiguration conf, INDArray gammaView, boolean initializeParams) {
        if (initializeParams) {
            BatchNormalization bn = (BatchNormalization) conf.getLayer();
            gammaView.assign(bn.getGamma());
        }
        return gammaView;
    }
}
