/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.preprocessors;

import java.util.Arrays;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Input preprocessor that permutes the axes of incoming activations and undoes that permutation
 * on the backward pass.
 *
 * <p>If the configured indices contain one fewer entry than the rank of the activations, the
 * minibatch axis is assumed to be missing. A leading {@code 0} is then prepended, so axis 0 stays
 * in place and the remaining indices are used as-is. NOTE(review): this presumes the indices are
 * 1-based with respect to the non-batch axes (the Keras {@code Permute} convention) — confirm with
 * callers.
 *
 * <p>Not thread-safe: {@link #preProcess} may rewrite {@link #permutationIndices} and
 * {@link #hasLeadingDimension} on first use, which also changes {@link #equals} and
 * {@link #hashCode}.
 */
@JsonIgnoreProperties(value={"hasLeadingDimension"})
public class PermutePreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(PermutePreprocessor.class);
    /** Axis order applied in the forward pass; may gain a leading 0 the first time it is used. */
    private int[] permutationIndices;
    /** True once {@link #preProcess} has prepended the minibatch axis to the indices. */
    private boolean hasLeadingDimension = false;

    /**
     * Creates a preprocessor with the given axis order.
     *
     * @param permutationIndices the axis order, with or without the leading minibatch axis
     */
    public PermutePreprocessor(int ... permutationIndices) {
        this.permutationIndices = permutationIndices;
    }

    /**
     * Returns a copy of {@code indices} with an extra {@code 0} in front, i.e. the minibatch axis is
     * kept in place and the remaining indices are copied unchanged.
     */
    private static int[] prependZero(int[] indices) {
        int[] augmented = new int[indices.length + 1];
        // augmented[0] is already 0 after allocation.
        System.arraycopy(indices, 0, augmented, 1, indices.length);
        return augmented;
    }

    /**
     * Returns the permutation to use for an array of the given rank. The minibatch axis is prepended
     * when the configured indices are one short. This method never modifies instance state.
     */
    private int[] permutationForRank(int rank) {
        return this.permutationIndices.length + 1 == rank
                ? prependZero(this.permutationIndices)
                : this.permutationIndices;
    }

    /**
     * Computes the inverse of a permutation, so that {@code inverse[permutation[i]] == i}.
     *
     * @throws IllegalArgumentException if {@code permutation} is not a permutation of
     *     {@code 0..length-1}
     */
    private static int[] invertPermutation(int[] permutation) {
        int[] inverse = new int[permutation.length];
        Arrays.fill(inverse, -1);
        for (int i = 0; i < permutation.length; ++i) {
            int axis = permutation[i];
            if (axis < 0 || axis >= permutation.length || inverse[axis] != -1) {
                throw new IllegalArgumentException("Invalid permutation " + Arrays.toString(permutation));
            }
            inverse[axis] = i;
        }
        return inverse;
    }

    /**
     * Permutes the activations according to {@link #permutationIndices}.
     *
     * <p>If the indices are exactly one shorter than the input rank, they are expanded with a leading
     * {@code 0}. The expanded form is stored back on this instance, and
     * {@link #hasLeadingDimension} is set, so that later calls use it directly.
     */
    @Override
    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (this.permutationIndices.length + 1 == input.rank()) {
            this.permutationIndices = PermutePreprocessor.prependZero(this.permutationIndices);
            this.hasLeadingDimension = true;
        }
        // Normalise to a plain c-ordered buffer before taking the permuted view.
        if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(this.permutationIndices));
    }

    /**
     * Maps the gradient from the permuted layout back to the input layout.
     *
     * <p>Applying the forward permutation again is only correct when it is its own inverse (a
     * simple swap). The inverse permutation is therefore applied instead.
     */
    @Override
    public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        int[] inverse = invertPermutation(permutationForRank(output.rank()));
        if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
            // This is a gradient array, so it belongs in the activation-gradient workspace.
            output = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, output, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(inverse));
    }

    /**
     * Computes the output type for a given input type.
     *
     * <p>For convolutional and recurrent inputs, the two non-batch dimensions are passed to the
     * factory in swapped positions. NOTE(review): this hard-codes a two-axis swap and does not
     * consult {@link #permutationIndices}; it also does not carry over any data-format information
     * from the input type. Feed-forward and 3D convolutional types are returned unchanged.
     *
     * @throws InvalidInputTypeException for any other input type
     */
    @Override
    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        if (inputType instanceof InputType.InputTypeConvolutional) {
            InputType.InputTypeConvolutional it = (InputType.InputTypeConvolutional)inputType;
            return InputType.convolutional(it.getWidth(), it.getHeight(), it.getChannels());
        }
        if (inputType instanceof InputType.InputTypeRecurrent) {
            InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent)inputType;
            return InputType.recurrent(it.getTimeSeriesLength(), it.getSize());
        }
        if (inputType instanceof InputType.InputTypeFeedForward || inputType instanceof InputType.InputTypeConvolutional3D) {
            return inputType;
        }
        throw new InvalidInputTypeException("Unsupported Input type " + inputType);
    }

    /** Returns the current permutation indices (the internal array, not a copy). */
    public int[] getPermutationIndices() {
        return this.permutationIndices;
    }

    /** Returns whether {@link #preProcess} has prepended the minibatch axis to the indices. */
    public boolean isHasLeadingDimension() {
        return this.hasLeadingDimension;
    }

    public void setPermutationIndices(int[] permutationIndices) {
        this.permutationIndices = permutationIndices;
    }

    public void setHasLeadingDimension(boolean hasLeadingDimension) {
        this.hasLeadingDimension = hasLeadingDimension;
    }

    @Override
    public String toString() {
        return "PermutePreprocessor(permutationIndices=" + Arrays.toString(this.getPermutationIndices()) + ", hasLeadingDimension=" + this.isHasLeadingDimension() + ")";
    }

    /**
     * Compares the permutation indices and the leading-dimension flag. Both can change during
     * {@link #preProcess}, so equality can differ before and after first use.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof PermutePreprocessor)) {
            return false;
        }
        PermutePreprocessor other = (PermutePreprocessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.isHasLeadingDimension() != other.isHasLeadingDimension()) {
            return false;
        }
        return Arrays.equals(this.getPermutationIndices(), other.getPermutationIndices());
    }

    protected boolean canEqual(Object other) {
        return other instanceof PermutePreprocessor;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.isHasLeadingDimension() ? 79 : 97);
        result = result * prime + Arrays.hashCode(this.getPermutationIndices());
        return result;
    }
}

