/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.Convolution1DUtils;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

@JsonIgnoreProperties(value={"paramShapes"})
/**
 * 1D locally connected layer, implemented as a SameDiff layer.
 * <p>
 * Like a 1D convolution, a window of {@code kernel} time steps is extracted for every output step. Unlike a
 * convolution, every output step has its own weight matrix. The weight parameter has shape
 * {@code [outputSize, featureDim, nOut]} with {@code featureDim = kernel * nIn}. Input is consumed in NCW layout
 * {@code [miniBatch, nIn, timeSteps]}. Output is produced as {@code [miniBatch, nOut, outputSize]}.
 */
public class LocallyConnected1D
extends SameDiffLayer {
    /** Parameter key of the per-step weight tensor. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter key of the (optional) bias vector. */
    private static final String BIAS_KEY = "b";
    private static final List<String> WEIGHT_KEYS = Collections.singletonList(WEIGHT_KEY);
    private static final List<String> BIAS_KEYS = Collections.singletonList(BIAS_KEY);
    private static final List<String> PARAM_KEYS = Arrays.asList(BIAS_KEY, WEIGHT_KEY);

    /** Number of input channels. */
    private long nIn;
    /** Number of output channels. */
    private long nOut;
    private Activation activation;
    /** Number of taps in each window. */
    private int kernel;
    private int stride;
    /** Left (and, outside Same mode, right) padding; recomputed by {@link #computeOutputSize()} in Same mode. */
    private int padding;
    /** Right padding; only computed and applied in Same mode. */
    private int paddingR;
    private ConvolutionMode cm;
    private int dilation;
    private boolean hasBias;
    /** Input time-series length; required before the output size can be computed. */
    private int inputSize;
    /** Number of output time steps (and the leading dimension of the weights). */
    private int outputSize;
    /** Size of one flattened input window, {@code kernel * nIn}. */
    private int featureDim;

    protected LocallyConnected1D(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
        this.activation = builder.activation;
        this.kernel = builder.kernel;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.cm = builder.cm;
        this.dilation = builder.dilation;
        this.hasBias = builder.hasBias;
        this.inputSize = builder.inputSize;
        this.featureDim = this.kernel * (int) this.nIn;
    }

    /** No-arg constructor, used for deserialization. */
    private LocallyConnected1D() {
    }

    /**
     * Computes {@link #getOutputSize()} from the current input size, kernel, stride, padding, mode and dilation.
     * In Same mode this also (re)computes the left/right padding.
     *
     * @throws IllegalArgumentException if the input size has not been set
     */
    public void computeOutputSize() {
        int channels = (int) this.getNIn();
        if (this.inputSize == 0) {
            throw new IllegalArgumentException("Input size has to be set for Locally connected layers");
        }
        // Convolution1DUtils infers the length from an array, so a dummy [1, nIn, inputSize] array is built.
        INDArray shapeProbe = Nd4j.ones(new int[]{1, channels, this.inputSize});
        if (this.cm == ConvolutionMode.Same) {
            this.outputSize = Convolution1DUtils.getOutputSize(shapeProbe, this.kernel, this.stride, 0, this.cm, this.dilation);
            this.padding = Convolution1DUtils.getSameModeTopLeftPadding(this.outputSize, this.inputSize, this.kernel, this.stride, this.dilation);
            this.paddingR = Convolution1DUtils.getSameModeBottomRightPadding(this.outputSize, this.inputSize, this.kernel, this.stride, this.dilation);
        } else {
            this.outputSize = Convolution1DUtils.getOutputSize(shapeProbe, this.kernel, this.stride, this.padding, this.cm, this.dilation);
        }
    }

    /**
     * Records the input length from the (recurrent / CNN1D) input type, computes the output size and returns the
     * resulting output type.
     * <p>
     * The configured dilation is passed through, so the reported output length matches the length computed by
     * {@link #computeOutputSize()}. That length also sizes the weights and the forward pass.
     *
     * @throws IllegalArgumentException if {@code inputType} is null or not of RNN type
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalArgumentException("Provided input type for locally connected 1D layers has to be of CNN1D/RNN type, got: " + inputType);
        }
        InputType.InputTypeRecurrent rnnType = (InputType.InputTypeRecurrent) inputType;
        this.inputSize = (int) rnnType.getTimeSeriesLength();
        this.computeOutputSize();
        return InputTypeUtil.getOutputTypeCnn1DLayers(inputType, this.kernel, this.stride, this.padding, this.dilation, this.cm, this.nOut, layerIndex, this.getLayerName(), LocallyConnected1D.class);
    }

    /**
     * Sets {@code nIn} and {@code featureDim} from the input type if they are unset (non-positive) or
     * {@code override} is true.
     *
     * @throws IllegalArgumentException if a value must be read and {@code inputType} is not a recurrent input type
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        boolean updateNIn = this.nIn <= 0L || override;
        boolean updateFeatureDim = this.featureDim <= 0 || override;
        if (!updateNIn && !updateFeatureDim) {
            return;
        }
        if (!(inputType instanceof InputType.InputTypeRecurrent)) {
            throw new IllegalArgumentException("Provided input type for locally connected 1D layers has to be of CNN1D/RNN type, got: " + inputType);
        }
        long channels = ((InputType.InputTypeRecurrent) inputType).getSize();
        if (updateNIn) {
            this.nIn = channels;
        }
        if (updateFeatureDim) {
            this.featureDim = this.kernel * (int) channels;
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, this.getLayerName());
    }

    /**
     * Declares the weight tensor {@code [outputSize, featureDim, nOut]} and, if enabled, the bias {@code [nOut]}.
     *
     * @throws IllegalStateException if {@code featureDim} has not been set to a positive value
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        Preconditions.checkState(this.featureDim > 0, "Cannot initialize layer: Feature dimension is set to %s", this.featureDim);
        params.clear();
        params.addWeightParam(WEIGHT_KEY, new long[]{this.outputSize, this.featureDim, this.nOut});
        if (this.hasBias) {
            params.addBiasParam(BIAS_KEY, new long[]{this.nOut});
        }
    }

    /**
     * Initializes parameters outside any workspace: the bias is zeroed and the weights use the configured weight
     * init with convolution-style fan-in ({@code nIn * kernel}) and fan-out ({@code nOut * kernel / stride}).
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray value = entry.getValue();
                if (BIAS_KEY.equals(entry.getKey())) {
                    value.assign(0);
                    continue;
                }
                double fanIn = this.nIn * (long) this.kernel;
                double fanOut = (double) (this.nOut * (long) this.kernel) / (double) this.stride;
                WeightInitUtil.initWeights(fanIn, fanOut, value.shape(), this.weightInit, null, 'c', value);
            }
        }
    }

    /**
     * Builds the forward pass.
     * <ol>
     * <li>Pads the time axis.</li>
     * <li>For every output step, extracts the input window and flattens it to {@code [1, miniBatch, featureDim]}.</li>
     * <li>Concatenates the windows to {@code [outputSize, miniBatch, featureDim]}.</li>
     * <li>Multiplies by the per-step weights {@code [outputSize, featureDim, nOut]}.</li>
     * <li>Permutes to {@code [miniBatch, nOut, outputSize]}.</li>
     * <li>Optionally adds the bias, then applies the activation.</li>
     * </ol>
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        SDVariable w = paramTable.get(WEIGHT_KEY);
        SDVariable padded = this.padTimeAxis(sameDiff, layerInput);
        int outH = this.outputSize;
        SDVariable[] windows = new SDVariable[outH];
        for (int step = 0; step < outH; ++step) {
            SDVariable slice = padded.get(SDIndex.all(), SDIndex.all(), this.windowIndex(step));
            windows[step] = sameDiff.reshape(slice, new long[]{1L, -1L, this.featureDim});
        }
        SDVariable stacked = sameDiff.concat(0, windows);
        SDVariable mmulResult = sameDiff.mmul(stacked, w);
        SDVariable result = sameDiff.permute(mmulResult, new int[]{1, 2, 0});
        if (this.hasBias) {
            SDVariable b = paramTable.get(BIAS_KEY);
            SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
            return this.activation.asSameDiff("out", sameDiff, biasAddedResult);
        }
        return this.activation.asSameDiff("out", sameDiff, result);
    }

    /**
     * Zero-pads axis 2 (time) of the NCW input. In Same mode the right padding may differ from the left; otherwise
     * padding is symmetric. Returns the input unchanged when no padding is needed.
     */
    private SDVariable padTimeAxis(SameDiff sameDiff, SDVariable input) {
        boolean sameMode = this.cm == ConvolutionMode.Same;
        if (this.padding <= 0 && !(sameMode && this.paddingR > 0)) {
            return input;
        }
        int right = sameMode ? this.paddingR : this.padding;
        int[][] pads = {{0, 0}, {0, 0}, {this.padding, right}};
        return sameDiff.nn().pad(input, sameDiff.constant(Nd4j.createFromArray(pads)), PadMode.CONSTANT, 0.0);
    }

    /**
     * Time-axis index selecting the {@code kernel} taps used by output step {@code step}. With dilation &gt; 1 the
     * taps are {@code dilation} apart and span {@code (kernel - 1) * dilation + 1} steps, matching the size used by
     * {@link #computeOutputSize()}. Otherwise a contiguous window of {@code kernel} steps is used.
     */
    private SDIndex windowIndex(int step) {
        int start = step * this.stride;
        if (this.dilation > 1) {
            int span = (this.kernel - 1) * this.dilation + 1;
            return SDIndex.interval(start, this.dilation, start + span);
        }
        return SDIndex.interval(start, start + this.kernel);
    }

    /** Fills in activation and convolution mode from the global configuration when not set on this layer. */
    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        if (this.activation == null) {
            this.activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
        if (this.cm == null) {
            this.cm = globalConfig.getConvolutionMode();
        }
    }

    public long getNIn() {
        return this.nIn;
    }

    public long getNOut() {
        return this.nOut;
    }

    public Activation getActivation() {
        return this.activation;
    }

    public int getKernel() {
        return this.kernel;
    }

    public int getStride() {
        return this.stride;
    }

    public int getPadding() {
        return this.padding;
    }

    public int getPaddingR() {
        return this.paddingR;
    }

    public ConvolutionMode getCm() {
        return this.cm;
    }

    public int getDilation() {
        return this.dilation;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public int getInputSize() {
        return this.inputSize;
    }

    public int getOutputSize() {
        return this.outputSize;
    }

    public int getFeatureDim() {
        return this.featureDim;
    }

    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setActivation(Activation activation) {
        this.activation = activation;
    }

    public void setKernel(int kernel) {
        this.kernel = kernel;
    }

    public void setStride(int stride) {
        this.stride = stride;
    }

    public void setPadding(int padding) {
        this.padding = padding;
    }

    public void setPaddingR(int paddingR) {
        this.paddingR = paddingR;
    }

    public void setCm(ConvolutionMode cm) {
        this.cm = cm;
    }

    public void setDilation(int dilation) {
        this.dilation = dilation;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setInputSize(int inputSize) {
        this.inputSize = inputSize;
    }

    public void setOutputSize(int outputSize) {
        this.outputSize = outputSize;
    }

    public void setFeatureDim(int featureDim) {
        this.featureDim = featureDim;
    }

    @Override
    public String toString() {
        return "LocallyConnected1D(nIn=" + this.getNIn() + ", nOut=" + this.getNOut() + ", activation=" + this.getActivation() + ", kernel=" + this.getKernel() + ", stride=" + this.getStride() + ", padding=" + this.getPadding() + ", paddingR=" + this.getPaddingR() + ", cm=" + this.getCm() + ", dilation=" + this.getDilation() + ", hasBias=" + this.isHasBias() + ", inputSize=" + this.getInputSize() + ", outputSize=" + this.getOutputSize() + ", featureDim=" + this.getFeatureDim() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LocallyConnected1D)) {
            return false;
        }
        LocallyConnected1D other = (LocallyConnected1D) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        // Activation and ConvolutionMode are enums, so reference comparison is equivalent to (null-safe) equals.
        return this.getNIn() == other.getNIn()
                && this.getNOut() == other.getNOut()
                && this.getKernel() == other.getKernel()
                && this.getStride() == other.getStride()
                && this.getPadding() == other.getPadding()
                && this.getPaddingR() == other.getPaddingR()
                && this.getDilation() == other.getDilation()
                && this.isHasBias() == other.isHasBias()
                && this.getInputSize() == other.getInputSize()
                && this.getOutputSize() == other.getOutputSize()
                && this.getFeatureDim() == other.getFeatureDim()
                && this.getActivation() == other.getActivation()
                && this.getCm() == other.getCm();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LocallyConnected1D;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Long.hashCode(this.getNIn());
        result = result * prime + Long.hashCode(this.getNOut());
        result = result * prime + this.getKernel();
        result = result * prime + this.getStride();
        result = result * prime + this.getPadding();
        result = result * prime + this.getPaddingR();
        result = result * prime + this.getDilation();
        result = result * prime + (this.isHasBias() ? 79 : 97);
        result = result * prime + this.getInputSize();
        result = result * prime + this.getOutputSize();
        result = result * prime + this.getFeatureDim();
        Activation act = this.getActivation();
        result = result * prime + (act == null ? 43 : act.hashCode());
        ConvolutionMode mode = this.getCm();
        result = result * prime + (mode == null ? 43 : mode.hashCode());
        return result;
    }

    /**
     * Builder for {@link LocallyConnected1D}.
     * <p>
     * Defaults: TANH activation, kernel 2, stride 1, padding 0, dilation 1, Same convolution mode, bias enabled.
     */
    public static class Builder
    extends SameDiffLayer.Builder<Builder> {
        private int nIn;
        private int nOut;
        private Activation activation = Activation.TANH;
        private int kernel = 2;
        private int stride = 1;
        private int padding = 0;
        private int dilation = 1;
        private int inputSize;
        private ConvolutionMode cm = ConvolutionMode.Same;
        private boolean hasBias = true;

        public Builder nIn(int nIn) {
            this.setNIn(nIn);
            return this;
        }

        public Builder nOut(int nOut) {
            this.setNOut(nOut);
            return this;
        }

        public Builder activation(Activation activation) {
            this.setActivation(activation);
            return this;
        }

        public Builder kernelSize(int k) {
            this.setKernel(k);
            return this;
        }

        public Builder stride(int s) {
            this.setStride(s);
            return this;
        }

        public Builder padding(int p) {
            this.setPadding(p);
            return this;
        }

        public Builder convolutionMode(ConvolutionMode cm) {
            this.setCm(cm);
            return this;
        }

        public Builder dilation(int d) {
            this.setDilation(d);
            return this;
        }

        public Builder hasBias(boolean hasBias) {
            this.setHasBias(hasBias);
            return this;
        }

        /** Sets the input time-series length (also filled in later from the input type). */
        public Builder setInputSize(int inputSize) {
            this.inputSize = inputSize;
            return this;
        }

        /** Validates mode/padding and kernel/stride/padding, then builds the layer. */
        @Override
        public LocallyConnected1D build() {
            Convolution1DUtils.validateConvolutionModePadding(this.cm, this.padding);
            Convolution1DUtils.validateCnn1DKernelStridePadding(this.kernel, this.stride, this.padding);
            return new LocallyConnected1D(this);
        }

        public int getNIn() {
            return this.nIn;
        }

        public int getNOut() {
            return this.nOut;
        }

        public Activation getActivation() {
            return this.activation;
        }

        public int getKernel() {
            return this.kernel;
        }

        public int getStride() {
            return this.stride;
        }

        public int getPadding() {
            return this.padding;
        }

        public int getDilation() {
            return this.dilation;
        }

        public int getInputSize() {
            return this.inputSize;
        }

        public ConvolutionMode getCm() {
            return this.cm;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public void setNIn(int nIn) {
            this.nIn = nIn;
        }

        public void setNOut(int nOut) {
            this.nOut = nOut;
        }

        public void setActivation(Activation activation) {
            this.activation = activation;
        }

        public void setKernel(int kernel) {
            this.kernel = kernel;
        }

        public void setStride(int stride) {
            this.stride = stride;
        }

        public void setPadding(int padding) {
            this.padding = padding;
        }

        public void setDilation(int dilation) {
            this.dilation = dilation;
        }

        public void setCm(ConvolutionMode cm) {
            this.cm = cm;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }
    }
}

