package org.deeplearning4j.nn.conf.layers.variational;

import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/variational/LossFunctionWrapper.class */
/**
 * Adapter that lets a plain {@link ILossFunction}, paired with an {@link IActivation}, be used
 * wherever a {@link ReconstructionDistribution} is expected.
 *
 * <p>A loss function is not a probability distribution, so two behaviours differ from a real
 * distribution:
 * <ul>
 *   <li>the values reported as "negative log probability" are the loss function's scores;</li>
 *   <li>"random" generation is deterministic and returns the same result as
 *       {@link #generateAtMean(INDArray)}.</li>
 * </ul>
 *
 * <p>Both collaborators are held in {@code final} fields and are never reassigned. Equality and
 * hashing are based on the activation function and the loss function, and both tolerate
 * {@code null} values.
 */
public class LossFunctionWrapper implements ReconstructionDistribution {

    /** Multiplier used to combine the per-field hash codes in {@link #hashCode()}. */
    private static final int HASH_MULTIPLIER = 59;

    /** Hash contribution used for a {@code null} field in {@link #hashCode()}. */
    private static final int NULL_FIELD_HASH = 43;

    /** Activation handed to the loss function and applied in {@link #generateAtMean(INDArray)}. */
    private final IActivation activationFn;

    /** Loss function that performs all scoring and gradient computation. */
    private final ILossFunction lossFunction;

    /**
     * Creates a wrapper from an activation function instance and a loss function.
     * Neither argument is null-checked here.
     *
     * @param activationFn activation function passed to every loss-function call and used by
     *                     {@link #generateAtMean(INDArray)}
     * @param lossFunction loss function to delegate scoring and gradients to
     */
    public LossFunctionWrapper(@JsonProperty("activationFn") IActivation activationFn,
                               @JsonProperty("lossFunction") ILossFunction lossFunction) {
        this.activationFn = activationFn;
        this.lossFunction = lossFunction;
    }

    /**
     * Convenience constructor that takes an {@link Activation} value instead of an
     * {@link IActivation} instance.
     *
     * @param activation   activation whose {@link Activation#getActivationFunction()} result is
     *                     wrapped. It is dereferenced, so {@code null} causes a
     *                     {@link NullPointerException}.
     * @param lossFunction loss function to delegate scoring and gradients to
     */
    public LossFunctionWrapper(Activation activation, ILossFunction lossFunction) {
        this(activation.getActivationFunction(), lossFunction);
    }

    /**
     * Always {@code true}: this distribution is backed by a loss function.
     *
     * @return {@code true}
     */
    @Override
    public boolean hasLossFunction() {
        return true;
    }

    /**
     * Returns the given size unchanged. This wrapper needs no extra distribution parameters.
     *
     * @param dataSize size of the data
     * @return {@code dataSize}
     */
    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    /**
     * Scores {@code x} with the wrapped loss function, using no mask.
     *
     * <p>The returned value is the loss score rather than a true negative log probability. It is
     * used in place of one.
     *
     * @param x                        array passed as the labels argument of
     *                                 {@link ILossFunction#computeScore}
     * @param preOutDistributionParams array passed as the pre-output argument
     * @param average                  forwarded unchanged as the loss function's average flag
     * @return the loss function's score
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        // The explicit cast keeps overload resolution unambiguous for the "no mask" argument.
        return lossFunction.computeScore(x, preOutDistributionParams, activationFn, (INDArray) null, average);
    }

    /**
     * Per-example loss scores, as computed by {@link ILossFunction#computeScoreArray}, with no mask.
     *
     * @param x                        array passed as the labels argument
     * @param preOutDistributionParams array passed as the pre-output argument
     * @return the loss function's per-example score array
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        return lossFunction.computeScoreArray(x, preOutDistributionParams, activationFn, (INDArray) null);
    }

    /**
     * Gradient of the loss, as computed by {@link ILossFunction#computeGradient}, with no mask.
     *
     * @param x                        array passed as the labels argument
     * @param preOutDistributionParams array passed as the pre-output argument
     * @return the loss function's gradient array
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        return lossFunction.computeGradient(x, preOutDistributionParams, activationFn, (INDArray) null);
    }

    /**
     * A loss function offers nothing to sample from, so this delegates to
     * {@link #generateAtMean(INDArray)} and is therefore deterministic.
     *
     * @param preOutDistributionParams pre-output array
     * @return the result of {@link #generateAtMean(INDArray)}
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        return generateAtMean(preOutDistributionParams);
    }

    /**
     * Applies the wrapped activation to a copy of the input.
     *
     * @param preOutDistributionParams pre-output array. It is copied with {@code dup()} before
     *                                 the activation is applied, so the caller's array is not
     *                                 passed to the activation.
     * @return the activation's result, computed with {@code true} as the second (training) flag
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        INDArray copy = preOutDistributionParams.dup();
        return activationFn.getActivation(copy, true);
    }

    @Override
    public String toString() {
        return new StringBuilder("LossFunctionWrapper(afn=")
                .append(activationFn)
                .append(",")
                .append(lossFunction)
                .append(")")
                .toString();
    }

    /** @return the wrapped activation function (may be {@code null}) */
    public IActivation getActivationFn() {
        return activationFn;
    }

    /** @return the wrapped loss function (may be {@code null}) */
    public ILossFunction getLossFunction() {
        return lossFunction;
    }

    /**
     * Two wrappers are equal when both accept each other via {@link #canEqual(Object)} and their
     * activation and loss functions are pairwise equal. Two {@code null}s count as equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossFunctionWrapper)) {
            return false;
        }
        LossFunctionWrapper other = (LossFunctionWrapper) obj;
        return other.canEqual(this)
                && nullSafeEquals(getActivationFn(), other.getActivationFn())
                && nullSafeEquals(getLossFunction(), other.getLossFunction());
    }

    /**
     * Symmetry hook consulted by {@link #equals(Object)} on the other operand. Subclasses may
     * override it to refuse equality with plain wrappers.
     *
     * @param other candidate object
     * @return whether {@code other} is a {@code LossFunctionWrapper}
     */
    protected boolean canEqual(Object other) {
        return other instanceof LossFunctionWrapper;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_MULTIPLIER + fieldHash(getActivationFn());
        result = result * HASH_MULTIPLIER + fieldHash(getLossFunction());
        return result;
    }

    /** Null-tolerant equality: two {@code null}s are equal; otherwise delegates to {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash of a single field, using {@link #NULL_FIELD_HASH} for {@code null}. */
    private static int fieldHash(Object value) {
        return value == null ? NULL_FIELD_HASH : value.hashCode();
    }
}
