package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.class */
/**
 * Layer configuration that wraps another layer configuration and exposes it to recurrent input.
 *
 * <p>For shape inference, the wrapped ("underlying") layer is configured as if it received
 * feed-forward input whose size equals the recurrent input's size. The resulting output is
 * reported as recurrent, with the same time-series length as the input. The runtime layer
 * created by {@link #instantiate} is a {@link TimeDistributedLayer} wrapping the underlying
 * layer's runtime instance.
 *
 * <p>The recurrent data format defaults to {@link RNNFormat#NCW}.
 */
public class TimeDistributed extends BaseWrapperLayer {

    /** Format of the recurrent activations; {@link RNNFormat#NCW} unless configured otherwise. */
    private RNNFormat rnnDataFormat = RNNFormat.NCW;

    /**
     * Creates a time-distributed wrapper with an explicit recurrent data format.
     *
     * @param underlying    the layer configuration to wrap; must not be {@code null}
     * @param rnnDataFormat format of the recurrent data; a {@code null} value (e.g. a property
     *                      missing from serialized JSON) falls back to {@link RNNFormat#NCW}
     * @throws NullPointerException if {@code underlying} is {@code null}
     */
    public TimeDistributed(@NonNull @JsonProperty("underlying") Layer underlying,
                           @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) {
        super(underlying);
        if (underlying == null) {
            throw new NullPointerException("underlying is marked non-null but is null");
        }
        // Without this fallback a null argument would clobber the NCW default and later be
        // passed as-is to TimeDistributedLayer and InputType.recurrent.
        this.rnnDataFormat = rnnDataFormat == null ? RNNFormat.NCW : rnnDataFormat;
    }

    /**
     * Creates a time-distributed wrapper using the default {@link RNNFormat#NCW} data format.
     *
     * @param underlying the layer configuration to wrap
     */
    public TimeDistributed(Layer underlying) {
        super(underlying);
    }

    /**
     * Instantiates the runtime layer: the underlying layer is instantiated against a copy of
     * {@code conf} whose layer is replaced by the wrapped configuration, and the result is wrapped
     * in a {@link TimeDistributedLayer} using this configuration's data format.
     *
     * <p>Assumes {@code conf.getLayer()} is a {@code TimeDistributed}; otherwise the cast below
     * throws {@link ClassCastException}.
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                                                       Collection<TrainingListener> trainingListeners,
                                                       int layerIndex,
                                                       INDArray layerParamsView,
                                                       boolean initializeParams,
                                                       DataType networkDataType) {
        // Work on a copy so the caller's configuration keeps pointing at this wrapper.
        NeuralNetConfiguration underlyingConf = conf.m39clone();
        underlyingConf.setLayer(((TimeDistributed) underlyingConf.getLayer()).getUnderlying());
        org.deeplearning4j.nn.api.Layer underlyingLayer = this.underlying.instantiate(
                underlyingConf, trainingListeners, layerIndex, layerParamsView,
                initializeParams, networkDataType);
        return new TimeDistributedLayer(underlyingLayer, this.rnnDataFormat);
    }

    /**
     * Computes the output type by asking the underlying layer for its output given feed-forward
     * input of the recurrent input's size, then re-wrapping that as recurrent output with the
     * input's time-series length and this layer's data format.
     *
     * @throws IllegalStateException if {@code inputType} is not {@link InputType.Type#RNN}
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer (layer #" + layerIndex + ")");
        }
        InputType.InputTypeRecurrent recurrentIn = (InputType.InputTypeRecurrent) inputType;
        InputType perStepIn = InputType.feedForward(recurrentIn.getSize());
        InputType perStepOut = this.underlying.getOutputType(layerIndex, perStepIn);
        return InputType.recurrent(perStepOut.arrayElementsPerExample(),
                recurrentIn.getTimeSeriesLength(), this.rnnDataFormat);
    }

    /**
     * Configures the underlying layer's nIn from a feed-forward view of the recurrent input.
     *
     * <p>Note: this always adopts the input type's data format, regardless of {@code override}.
     *
     * @throws IllegalStateException if {@code inputType} is not {@link InputType.Type#RNN}
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer");
        }
        InputType.InputTypeRecurrent recurrentIn = (InputType.InputTypeRecurrent) inputType;
        InputType perStepIn = InputType.feedForward(recurrentIn.getSize());
        this.rnnDataFormat = recurrentIn.getFormat();
        this.underlying.setNIn(perStepIn, override);
    }

    /**
     * Returns {@code null}: this configuration requests no input preprocessor.
     * (Presumably any reshaping is handled by {@link TimeDistributedLayer} — confirm there.)
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    /** @return the recurrent data format used by this layer */
    public RNNFormat getRnnDataFormat() {
        return this.rnnDataFormat;
    }

    /** @param rnnDataFormat the recurrent data format to use */
    public void setRnnDataFormat(RNNFormat rnnDataFormat) {
        this.rnnDataFormat = rnnDataFormat;
    }

    /** Describes this configuration by its data format only (the wrapped layer is not included). */
    @Override
    public String toString() {
        return "TimeDistributed(rnnDataFormat=" + getRnnDataFormat() + ")";
    }

    /** Equal when the superclass state is equal and both use the same data format. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TimeDistributed)) {
            return false;
        }
        TimeDistributed other = (TimeDistributed) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        RNNFormat mine = getRnnDataFormat();
        RNNFormat theirs = other.getRnnDataFormat();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Lombok-style symmetry hook for {@link #equals(Object)}. */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof TimeDistributed;
    }

    /** Combines the superclass hash with the data format's hash (43 when the format is null). */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        RNNFormat format = getRnnDataFormat();
        return (result * 59) + (format == null ? 43 : format.hashCode());
    }
}
