package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/CnnLossLayer.class */
/**
 * Parameterless output layer that applies the configured activation and loss function to a
 * rank 4 CNN activation array, using the layer's configured {@link CNN2DFormat} (NCHW or NHWC).
 *
 * <p>Input, labels and (optional) mask are flattened with {@link ConvolutionUtils#reshape4dTo2d}
 * before being handed to the loss function, and results are reshaped back to 4d with
 * {@link ConvolutionUtils#reshape2dTo4d}. Input and labels must have identical shapes.
 */
public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.CnnLossLayer> implements IOutputLayer {
    /** Labels set via {@link #setLabels(INDArray)}; expected to have the same shape as the input. */
    protected INDArray labels;

    public CnnLossLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the gradient of the loss with respect to this layer's (pre-activation) input.
     *
     * @param epsilon      unused: this is an output layer, so the gradient comes from the loss
     * @param workspaceMgr workspace manager used for all intermediate and returned arrays
     * @return an empty {@link DefaultGradient} (no parameters) and the input gradient, reshaped to
     *         the input's 4d shape
     * @throws UnsupportedOperationException if the input is not rank 4
     * @throws IllegalStateException         if labels are not set or do not match the input shape
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        if (this.input.rank() != 4) {
            throw new UnsupportedOperationException("Input is not rank 4. Got input with rank " + this.input.rank() + " " + layerId() + " with shape " + Arrays.toString(this.input.shape()) + " - expected shape " + layerConf().getFormat().dimensionNames());
        }
        if (this.labels == null) {
            throw new IllegalStateException("Labels are not set (null)");
        }
        Preconditions.checkState(this.input.equalShapes(this.labels), "Input and label arrays do not have same shape: %ndShape vs. %ndShape", this.input, this.labels);

        CNN2DFormat format = layerConf().getFormat();
        INDArray input2d = ConvolutionUtils.reshape4dTo2d(this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray labels2d = ConvolutionUtils.reshape4dTo2d(this.labels, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray mask2d = ConvolutionUtils.reshapeMaskIfRequired(this.maskArray, this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);

        // The loss function receives a copy of the flattened input, presumably so that any in-place
        // activation inside computeGradient cannot modify the stored input.
        INDArray gradient2d = layerConf().getLossFn().computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), mask2d);
        gradient2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, gradient2d);

        INDArray inputGradient = ConvolutionUtils.reshape2dTo4d(gradient2d, this.input.shape(), format, workspaceMgr, ArrayType.ACTIVATION_GRAD);
        return new Pair<>(new DefaultGradient(), inputGradient);
    }

    /** This layer has no parameters, so its regularization score is always 0. */
    @Override
    public double calcRegularizationScore(boolean backpropOnlyParams) {
        return 0.0;
    }

    /** Not implemented for this layer; always returns 0. */
    @Override
    public double f1Score(DataSet data) {
        return 0.0;
    }

    /**
     * Computes the F1 score of this layer's predictions for {@code examples} against {@code labels},
     * using the layer's current mask array.
     */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        INDArray out = activate(examples, false, null);
        Evaluation eval = new Evaluation();
        // NOTE(review): evalTimeSeries is given rank 4 CNN arrays here - confirm Evaluation handles them.
        eval.evalTimeSeries(labels, out, this.maskArray);
        return eval.f1();
    }

    /**
     * Returns the number of label classes, i.e. the size of the labels' channel dimension for the
     * configured {@link CNN2DFormat}. Labels must have been set.
     */
    @Override
    public int numLabels() {
        return (int) this.labels.size(channelDimension());
    }

    @Override
    public void fit(DataSetIterator iter) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<String> predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    /**
     * Applies the configured activation function to a copy of the input.
     *
     * @param training     whether the network is in training mode (forwarded to the activation)
     * @param workspaceMgr workspace manager for the returned activations
     * @return activations with the same 4d shape as the input
     * @throws UnsupportedOperationException if the input is not rank 4
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        if (this.input.rank() != 4) {
            throw new UnsupportedOperationException("Input must be rank 4 with shape " + layerConf().getFormat().dimensionNames() + ". Got input with rank " + this.input.rank() + " " + layerId());
        }
        CNN2DFormat format = layerConf().getFormat();
        // Activate a duplicate so the stored input is left untouched.
        INDArray inputCopy = workspaceMgr.dup(ArrayType.ACTIVATIONS, this.input, this.input.ordering());
        INDArray input2d = ConvolutionUtils.reshape4dTo2d(inputCopy, format, workspaceMgr, ArrayType.ACTIVATIONS);
        INDArray activated2d = layerConf().getActivationFn().getActivation(input2d, training);
        return ConvolutionUtils.reshape2dTo4d(activated2d, this.input.shape(), format, workspaceMgr, ArrayType.ACTIVATIONS);
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        this.maskArray = maskArray;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Stores the incoming mask for use by the loss computations. Returns {@code null} because no
     * mask is propagated further from this output layer.
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.maskArray = maskArray;
        return null;
    }

    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Computes the loss averaged over the minibatch size, plus the given regularization term, and
     * stores it as this layer's score.
     *
     * @param fullNetRegTerm regularization term to add to the loss
     * @param training       unused here
     * @param workspaceMgr   workspace manager for intermediate arrays
     * @return the score
     * @throws IllegalStateException if the input or labels are not set
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        // Flatten with the configured format, consistent with the mask reshape below and with the
        // other methods of this layer (previously the format-less overload was used here).
        CNN2DFormat format = layerConf().getFormat();
        INDArray labels2d = ConvolutionUtils.reshape4dTo2d(this.labels, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray input2d = ConvolutionUtils.reshape4dTo2d(this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray mask2d = ConvolutionUtils.reshapeMaskIfRequired(this.maskArray, this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);

        double lossSum = layerConf().getLossFn().computeScore(labels2d, input2d.dup(), layerConf().getActivationFn(), mask2d, false);
        double result = lossSum / getInputMiniBatchSize() + fullNetRegTerm;
        this.score = result;
        return result;
    }

    /**
     * Computes one score per example by summing the per-location loss over all spatial positions.
     *
     * @param fullNetRegTerm regularization term added to every example's score when non-zero
     * @param workspaceMgr   workspace manager for intermediate and returned arrays
     * @return a {@code [minibatch, 1]} array of per-example scores
     * @throws IllegalStateException if the input or labels are not set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        CNN2DFormat format = layerConf().getFormat();
        INDArray labels2d = ConvolutionUtils.reshape4dTo2d(this.labels, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray input2d = ConvolutionUtils.reshape4dTo2d(this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray mask2d = ConvolutionUtils.reshapeMaskIfRequired(this.maskArray, this.input, format, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray scoreArray = layerConf().getLossFn().computeScoreArray(labels2d, input2d, layerConf().getActivationFn(), mask2d);

        // Target shape: the input shape with a single channel. The channel axis depends on the
        // format (index 1 for NCHW, index 3 for NHWC); hard-coding index 1 broke NHWC layers.
        long[] perLocationShape = this.input.shape().clone();
        perLocationShape[channelDimension()] = 1;
        INDArray perLocationScores = ConvolutionUtils.reshape2dTo4d(scoreArray, perLocationShape, format, workspaceMgr, ArrayType.FF_WORKING_MEM);

        // Summing over every non-minibatch axis is format-independent.
        INDArray perExample = perLocationScores.sum(new int[]{1, 2, 3}).reshape(perLocationScores.size(0), 1L);
        if (fullNetRegTerm != 0.0) {
            perExample.addi(fullNetRegTerm);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, perExample);
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /**
     * Index of the channel axis of 4d activations/labels for the configured format:
     * 3 for NHWC, 1 otherwise (NCHW).
     */
    private int channelDimension() {
        return layerConf().getFormat() == CNN2DFormat.NHWC ? 3 : 1;
    }
}
