package org.deeplearning4j.nn.conf.dropout;

import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.Objects;

/**
 * Multiplicative Gaussian dropout.
 *
 * <p>On the forward pass each activation is multiplied by independently sampled noise drawn from a
 * normal distribution with mean {@code 1.0} and standard deviation {@code sqrt(r / (1 - r))}. Here
 * {@code r} is either the fixed {@link #getRate() rate} or, when present, the value of the
 * {@link #getRateSchedule() rate schedule} at the current iteration/epoch.
 *
 * <p>The sampled noise array is kept between {@link #applyDropout} and {@link #backprop} so the
 * gradient can be scaled by the same noise. It is transient, excluded from JSON (via
 * {@code @JsonIgnoreProperties}) and excluded from {@link #equals}/{@link #hashCode}.
 */
@JsonIgnoreProperties({"noise"})
public class GaussianDropout implements IDropout {
    /** Fixed rate; the schedule-based public constructor sets this to {@code NaN}. */
    private final double rate;
    /** Optional rate schedule; when non-null it takes precedence over {@link #rate}. */
    private final ISchedule rateSchedule;
    /** Noise sampled by the most recent forward pass; consumed and cleared by {@link #backprop}. */
    private transient INDArray noise;

    /**
     * Creates a Gaussian dropout with a fixed rate.
     *
     * @param rate rate {@code r}; noise standard deviation is {@code sqrt(r / (1 - r))}
     */
    public GaussianDropout(double rate) {
        this(rate, null);
    }

    /**
     * Creates a Gaussian dropout whose rate is taken from a schedule.
     *
     * @param rateSchedule schedule evaluated at each (iteration, epoch) to obtain the rate
     */
    public GaussianDropout(ISchedule rateSchedule) {
        this(Double.NaN, rateSchedule);
    }

    /**
     * Full constructor, also used for JSON deserialization.
     *
     * @param rate         fixed rate (ignored when {@code rateSchedule} is non-null)
     * @param rateSchedule optional rate schedule, may be {@code null}
     */
    protected GaussianDropout(@JsonProperty("rate") double rate,
                              @JsonProperty("rateSchedule") ISchedule rateSchedule) {
        this.rate = rate;
        this.rateSchedule = rateSchedule;
    }

    /**
     * Samples fresh noise and writes {@code inputActivations * noise} into {@code output}.
     *
     * @param inputActivations activations to apply dropout to
     * @param output           array receiving the result; its data type is used for the noise
     * @param iteration        current iteration, passed to the rate schedule
     * @param epoch            current epoch, passed to the rate schedule
     * @param workspaceMgr     workspace manager used to allocate the noise array (INPUT workspace)
     * @return the result of the multiplication
     * @throws IllegalStateException if the effective rate is not in {@code [0, 1)} (or is NaN)
     */
    @Override // org.deeplearning4j.nn.conf.dropout.IDropout
    public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch,
                                 LayerWorkspaceMgr workspaceMgr) {
        double r = rateSchedule != null ? rateSchedule.valueAt(iteration, epoch) : rate;
        // A rate outside [0, 1) (or NaN) yields a NaN/infinite standard deviation below, which would
        // silently fill the activations with NaN/Inf noise; fail fast with a clear message instead.
        Preconditions.checkState(r >= 0.0 && r < 1.0,
                "GaussianDropout rate must be in range [0, 1), got %s", r);
        double stdev = Math.sqrt(r / (1.0 - r));

        noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(),
                inputActivations.shape(), inputActivations.ordering());
        Nd4j.getExecutioner().exec(new GaussianDistribution(noise, 1.0, stdev));
        return Nd4j.getExecutioner().exec(new MulOp(inputActivations, noise, output))[0];
    }

    /**
     * Backpropagates through the dropout: since {@code out = in * noise}, the gradient with
     * respect to the input is {@code gradAtOutput * noise}. The cached noise is cleared afterwards.
     *
     * @param gradAtOutput gradient with respect to the dropout output
     * @param gradAtInput  array receiving the gradient with respect to the dropout input
     * @param iteration    current iteration (unused)
     * @param epoch        current epoch (unused)
     * @return {@code gradAtInput}
     * @throws IllegalStateException if no noise is cached (no forward pass, or already cleared)
     */
    @Override // org.deeplearning4j.nn.conf.dropout.IDropout
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        Preconditions.checkState(noise != null,
                "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)");
        Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, noise, gradAtInput));
        noise = null;
        return gradAtInput;
    }

    /** Discards any cached noise array. */
    @Override // org.deeplearning4j.nn.conf.dropout.IDropout
    public void clear() {
        noise = null;
    }

    /**
     * Returns a copy with the same rate and a cloned rate schedule. Cached noise is not copied.
     *
     * @return a new {@code GaussianDropout}
     */
    @Override // org.deeplearning4j.nn.conf.dropout.IDropout
    public GaussianDropout clone() {
        return new GaussianDropout(rate, rateSchedule == null ? null : rateSchedule.clone());
    }

    /**
     * Alias of {@link #clone()} kept for callers compiled against the decompiler-renamed method.
     *
     * @return a new {@code GaussianDropout}
     * @deprecated use {@link #clone()} instead.
     */
    @Deprecated
    public GaussianDropout m51clone() {
        return clone();
    }

    /** @return the fixed rate ({@code NaN} when constructed from a schedule) */
    public double getRate() {
        return rate;
    }

    /** @return the rate schedule, or {@code null} if a fixed rate is used */
    public ISchedule getRateSchedule() {
        return rateSchedule;
    }

    /** @return the noise cached by the last forward pass, or {@code null} */
    public INDArray getNoise() {
        return noise;
    }

    /** @param noise noise array to cache for the next {@link #backprop} call */
    public void setNoise(INDArray noise) {
        this.noise = noise;
    }

    @Override
    public String toString() {
        return "GaussianDropout(rate=" + getRate() + ", rateSchedule=" + getRateSchedule()
                + ", noise=" + getNoise() + ")";
    }

    /** Equality is based on {@code rate} and {@code rateSchedule}; the transient noise is ignored. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GaussianDropout)) {
            return false;
        }
        GaussianDropout other = (GaussianDropout) obj;
        return other.canEqual(this)
                && Double.compare(getRate(), other.getRate()) == 0
                && Objects.equals(getRateSchedule(), other.getRateSchedule());
    }

    /**
     * Symmetry hook for {@link #equals}: subclasses that add state should override this.
     *
     * @param obj candidate object
     * @return whether {@code obj} may be equal to this instance
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof GaussianDropout;
    }

    /** Hash over {@code rate} and {@code rateSchedule}, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        // Double.hashCode folds the 64-bit pattern exactly like (int) (bits ^ (bits >>> 32)).
        int result = 59 + Double.hashCode(getRate());
        ISchedule schedule = getRateSchedule();
        return result * 59 + (schedule == null ? 43 : schedule.hashCode());
    }
}
