/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.CapsuleUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.factory.Nd4j;

public class PrimaryCapsules
extends SameDiffLayer {
    private int[] kernelSize;
    private int[] stride;
    private int[] padding;
    private int[] dilation;
    private int inputChannels;
    private int channels;
    private boolean hasBias;
    private int capsules;
    private int capsuleDimensions;
    private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
    private boolean useRelu = false;
    private double leak = 0.0;
    private static final String WEIGHT_PARAM = "weight";
    private static final String BIAS_PARAM = "bias";

    public PrimaryCapsules(Builder builder) {
        super(builder);
        this.kernelSize = builder.kernelSize;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.dilation = builder.dilation;
        this.channels = builder.channels;
        this.hasBias = builder.hasBias;
        this.capsules = builder.capsules;
        this.capsuleDimensions = builder.capsuleDimensions;
        this.convolutionMode = builder.convolutionMode;
        this.useRelu = builder.useRelu;
        this.leak = builder.leak;
        if (this.capsuleDimensions <= 0 || this.channels <= 0) {
            throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \"" + this.layerName + "\"): capsuleDimensions and channels must be > 0.  Got: " + this.capsuleDimensions + ", " + this.channels);
        }
        if (this.capsules < 0) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): capsules must be >= 0 if set.  Got: " + this.capsules);
        }
    }

    @Override
    public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
        Conv2DConfig conf = Conv2DConfig.builder().kH((long)this.kernelSize[0]).kW((long)this.kernelSize[1]).sH((long)this.stride[0]).sW((long)this.stride[1]).pH((long)this.padding[0]).pW((long)this.padding[1]).dH((long)this.dilation[0]).dW((long)this.dilation[1]).paddingMode(ConvolutionMode.mapToMode(this.convolutionMode)).build();
        SDVariable conved = this.hasBias ? SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf) : SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
        if (this.useRelu) {
            conved = this.leak == 0.0 ? SD.nn.relu(conved, 0.0) : SD.nn.leakyRelu(conved, this.leak);
        }
        SDVariable reshaped = conved.reshape(new int[]{-1, this.capsules, this.capsuleDimensions});
        return CapsuleUtils.squash(SD, reshaped, 2);
    }

    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        params.addWeightParam(WEIGHT_PARAM, this.kernelSize[0], this.kernelSize[1], this.inputChannels, this.capsuleDimensions * this.channels);
        if (this.hasBias) {
            params.addBiasParam(BIAS_PARAM, this.capsuleDimensions * this.channels);
        }
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
            for (Map.Entry<String, INDArray> e : params.entrySet()) {
                if (BIAS_PARAM.equals(e.getKey())) {
                    e.getValue().assign((Number)0);
                    continue;
                }
                if (!WEIGHT_PARAM.equals(e.getKey())) continue;
                double fanIn = this.inputChannels * this.kernelSize[0] * this.kernelSize[1];
                double fanOut = (double)(this.capsuleDimensions * this.channels * this.kernelSize[0] * this.kernelSize[1]) / ((double)this.stride[0] * (double)this.stride[1]);
                WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), this.weightInit, null, 'c', e.getValue());
            }
        }
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \"" + this.layerName + "\"): expect CNN input.  Got: " + inputType);
        }
        if (this.capsules > 0) {
            return InputType.recurrent((long)this.capsules, this.capsuleDimensions);
        }
        InputType.InputTypeConvolutional out = (InputType.InputTypeConvolutional)InputTypeUtil.getOutputTypeCnnLayers(inputType, this.kernelSize, this.stride, this.padding, this.dilation, this.convolutionMode, this.capsuleDimensions * this.channels, -1L, this.getLayerName(), PrimaryCapsules.class);
        return InputType.recurrent((long)((int)(out.getChannels() * out.getHeight() * out.getWidth() / (long)this.capsuleDimensions)), this.capsuleDimensions);
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \"" + this.layerName + "\"): expect CNN input.  Got: " + inputType);
        }
        InputType.InputTypeConvolutional ci = (InputType.InputTypeConvolutional)inputType;
        this.inputChannels = (int)ci.getChannels();
        if (this.capsules <= 0 || override) {
            InputType.InputTypeConvolutional out = (InputType.InputTypeConvolutional)InputTypeUtil.getOutputTypeCnnLayers(inputType, this.kernelSize, this.stride, this.padding, this.dilation, this.convolutionMode, this.capsuleDimensions * this.channels, -1L, this.getLayerName(), PrimaryCapsules.class);
            this.capsules = (int)(out.getChannels() * out.getHeight() * out.getWidth() / (long)this.capsuleDimensions);
        }
    }

    public int[] getKernelSize() {
        return this.kernelSize;
    }

    public int[] getStride() {
        return this.stride;
    }

    public int[] getPadding() {
        return this.padding;
    }

    public int[] getDilation() {
        return this.dilation;
    }

    public int getInputChannels() {
        return this.inputChannels;
    }

    public int getChannels() {
        return this.channels;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public int getCapsules() {
        return this.capsules;
    }

    public int getCapsuleDimensions() {
        return this.capsuleDimensions;
    }

    public ConvolutionMode getConvolutionMode() {
        return this.convolutionMode;
    }

    public boolean isUseRelu() {
        return this.useRelu;
    }

    public double getLeak() {
        return this.leak;
    }

    public void setKernelSize(int[] kernelSize) {
        this.kernelSize = kernelSize;
    }

    public void setStride(int[] stride) {
        this.stride = stride;
    }

    public void setPadding(int[] padding) {
        this.padding = padding;
    }

    public void setDilation(int[] dilation) {
        this.dilation = dilation;
    }

    public void setInputChannels(int inputChannels) {
        this.inputChannels = inputChannels;
    }

    public void setChannels(int channels) {
        this.channels = channels;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setCapsules(int capsules) {
        this.capsules = capsules;
    }

    public void setCapsuleDimensions(int capsuleDimensions) {
        this.capsuleDimensions = capsuleDimensions;
    }

    public void setConvolutionMode(ConvolutionMode convolutionMode) {
        this.convolutionMode = convolutionMode;
    }

    public void setUseRelu(boolean useRelu) {
        this.useRelu = useRelu;
    }

    public void setLeak(double leak) {
        this.leak = leak;
    }

    @Override
    public String toString() {
        return "PrimaryCapsules(kernelSize=" + Arrays.toString(this.getKernelSize()) + ", stride=" + Arrays.toString(this.getStride()) + ", padding=" + Arrays.toString(this.getPadding()) + ", dilation=" + Arrays.toString(this.getDilation()) + ", inputChannels=" + this.getInputChannels() + ", channels=" + this.getChannels() + ", hasBias=" + this.isHasBias() + ", capsules=" + this.getCapsules() + ", capsuleDimensions=" + this.getCapsuleDimensions() + ", convolutionMode=" + this.getConvolutionMode() + ", useRelu=" + this.isUseRelu() + ", leak=" + this.getLeak() + ")";
    }

    public PrimaryCapsules() {
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof PrimaryCapsules)) {
            return false;
        }
        PrimaryCapsules other = (PrimaryCapsules)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (this.getInputChannels() != other.getInputChannels()) {
            return false;
        }
        if (this.getChannels() != other.getChannels()) {
            return false;
        }
        if (this.isHasBias() != other.isHasBias()) {
            return false;
        }
        if (this.getCapsules() != other.getCapsules()) {
            return false;
        }
        if (this.getCapsuleDimensions() != other.getCapsuleDimensions()) {
            return false;
        }
        if (this.isUseRelu() != other.isUseRelu()) {
            return false;
        }
        if (Double.compare(this.getLeak(), other.getLeak()) != 0) {
            return false;
        }
        if (!Arrays.equals(this.getKernelSize(), other.getKernelSize())) {
            return false;
        }
        if (!Arrays.equals(this.getStride(), other.getStride())) {
            return false;
        }
        if (!Arrays.equals(this.getPadding(), other.getPadding())) {
            return false;
        }
        if (!Arrays.equals(this.getDilation(), other.getDilation())) {
            return false;
        }
        ConvolutionMode this$convolutionMode = this.getConvolutionMode();
        ConvolutionMode other$convolutionMode = other.getConvolutionMode();
        return !(this$convolutionMode == null ? other$convolutionMode != null : !((Object)((Object)this$convolutionMode)).equals((Object)other$convolutionMode));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof PrimaryCapsules;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        result = result * 59 + this.getInputChannels();
        result = result * 59 + this.getChannels();
        result = result * 59 + (this.isHasBias() ? 79 : 97);
        result = result * 59 + this.getCapsules();
        result = result * 59 + this.getCapsuleDimensions();
        result = result * 59 + (this.isUseRelu() ? 79 : 97);
        long $leak = Double.doubleToLongBits(this.getLeak());
        result = result * 59 + (int)($leak >>> 32 ^ $leak);
        result = result * 59 + Arrays.hashCode(this.getKernelSize());
        result = result * 59 + Arrays.hashCode(this.getStride());
        result = result * 59 + Arrays.hashCode(this.getPadding());
        result = result * 59 + Arrays.hashCode(this.getDilation());
        ConvolutionMode $convolutionMode = this.getConvolutionMode();
        result = result * 59 + ($convolutionMode == null ? 43 : ((Object)((Object)$convolutionMode)).hashCode());
        return result;
    }

    public static class Builder
    extends SameDiffLayer.Builder<Builder> {
        private int[] kernelSize = new int[]{9, 9};
        private int[] stride = new int[]{2, 2};
        private int[] padding = new int[]{0, 0};
        private int[] dilation = new int[]{1, 1};
        private int channels = 32;
        private boolean hasBias = true;
        private int capsules;
        private int capsuleDimensions;
        private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
        private boolean useRelu = false;
        private double leak = 0.0;

        public void setKernelSize(int ... kernelSize) {
            this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
        }

        public void setStride(int ... stride) {
            this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride");
        }

        public void setPadding(int ... padding) {
            this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding");
        }

        public void setDilation(int ... dilation) {
            this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode) {
            this.capsuleDimensions = capsuleDimensions;
            this.channels = channels;
            this.setKernelSize(kernelSize);
            this.setStride(stride);
            this.setPadding(padding);
            this.setDilation(dilation);
            this.convolutionMode = convolutionMode;
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
            this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding) {
            this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize, int[] stride) {
            this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels, int[] kernelSize) {
            this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder(int capsuleDimensions, int channels) {
            this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
        }

        public Builder kernelSize(int ... kernelSize) {
            this.setKernelSize(kernelSize);
            return this;
        }

        public Builder stride(int ... stride) {
            this.setStride(stride);
            return this;
        }

        public Builder padding(int ... padding) {
            this.setPadding(padding);
            return this;
        }

        public Builder dilation(int ... dilation) {
            this.setDilation(dilation);
            return this;
        }

        public Builder channels(int channels) {
            this.channels = channels;
            return this;
        }

        public Builder nOut(int nOut) {
            return this.channels(nOut);
        }

        public Builder capsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
            return this;
        }

        public Builder capsules(int capsules) {
            this.capsules = capsules;
            return this;
        }

        public Builder hasBias(boolean hasBias) {
            this.hasBias = hasBias;
            return this;
        }

        public Builder convolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
            return this;
        }

        public Builder useReLU(boolean useRelu) {
            this.useRelu = useRelu;
            return this;
        }

        public Builder useReLU() {
            return this.useReLU(true);
        }

        public Builder useLeakyReLU(double leak) {
            this.useRelu = true;
            this.leak = leak;
            return this;
        }

        @Override
        public <E extends Layer> E build() {
            return (E)new PrimaryCapsules(this);
        }

        public int[] getKernelSize() {
            return this.kernelSize;
        }

        public int[] getStride() {
            return this.stride;
        }

        public int[] getPadding() {
            return this.padding;
        }

        public int[] getDilation() {
            return this.dilation;
        }

        public int getChannels() {
            return this.channels;
        }

        public boolean isHasBias() {
            return this.hasBias;
        }

        public int getCapsules() {
            return this.capsules;
        }

        public int getCapsuleDimensions() {
            return this.capsuleDimensions;
        }

        public ConvolutionMode getConvolutionMode() {
            return this.convolutionMode;
        }

        public boolean isUseRelu() {
            return this.useRelu;
        }

        public double getLeak() {
            return this.leak;
        }

        public void setChannels(int channels) {
            this.channels = channels;
        }

        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        public void setCapsules(int capsules) {
            this.capsules = capsules;
        }

        public void setCapsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
        }

        public void setConvolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
        }

        public void setUseRelu(boolean useRelu) {
            this.useRelu = useRelu;
        }

        public void setLeak(double leak) {
            this.leak = leak;
        }
    }
}

