/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.embedding;

import java.util.Arrays;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runtime implementation of the sequence embedding layer.
 *
 * <p>Each integer index in an input of shape {@code [minibatch, seqLength]} (or
 * {@code [minibatch, 1, seqLength]}; rank-1 input of length N is treated as {@code [N, 1, 1]})
 * selects one row of the weight matrix {@code W}, optionally offset by the bias {@code b}. The
 * selected rows are returned as a rank-3 array of shape {@code [minibatch, nOut, seqLength]} when
 * the configured output format is {@link RNNFormat#NCW}, and {@code [minibatch, seqLength, nOut]}
 * otherwise.
 */
public class EmbeddingSequenceLayer
extends BaseLayer<org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer> {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingSequenceLayer.class);
    /** Parameter / gradient key of the embedding matrix. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter / gradient key of the bias vector. */
    private static final String BIAS_KEY = "b";
    /** Axis argument passed to {@code Nd4j.scatterUpdate} when accumulating weight gradients. */
    private static final int[] WEIGHT_DIM = new int[]{1};
    /**
     * Flattened (row-major) embedding indices captured by the most recent {@link #preOutput} call;
     * consumed by {@link #backpropGradient} and reset by {@link #clear()}.
     */
    private int[] indexes;

    public EmbeddingSequenceLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the gradients of the embedding weights (and bias, if present).
     *
     * <p>Recomputes the forward pre-activation, backpropagates {@code epsilon} through the
     * activation function, applies the mask (if set), and scatter-adds each per-position delta
     * row into the weight-gradient row selected by the corresponding input index.
     *
     * @param epsilon      gradient w.r.t. this layer's output, laid out like the layer output
     * @param workspaceMgr workspace manager used for the forward recomputation
     * @return the parameter gradients; the epsilon for the (index) input is always {@code null}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer layerConf = embeddingConf();
        INDArray z = this.preOutput(true, workspaceMgr);
        INDArray delta = layerConf.getActivationFn().backprop(z, epsilon).getFirst();
        boolean ncw = isNcw();
        if (this.maskArray != null) {
            // Cast once and multiply in place (castTo may allocate when the dtypes differ).
            INDArray castDelta = delta.castTo(z.dataType());
            delta = Broadcast.mul(castDelta, this.maskArray.castTo(z.dataType()), castDelta, maskBroadcastDims());
        }
        int inputLength = layerConf.getInputLength();
        long numSamples = this.input.size(0);
        long nOut = layerConf.getNOut();
        if (delta.ordering() != 'c' || delta.isView() || !Shape.hasDefaultStridesForShape(delta)) {
            delta = delta.dup('c');
        }
        if (ncw) {
            // [minibatch, nOut, seqLength] -> [minibatch, seqLength, nOut] so that, after the
            // reshape below, delta rows line up with the row-major order of `indexes`.
            delta = delta.permute(new int[]{0, 2, 1});
        }
        delta = delta.reshape('c', new long[]{(long) inputLength * numSamples, nOut});

        INDArray weightGradients = this.gradientViews.get(WEIGHT_KEY);
        weightGradients.assign(0);
        if (!Shape.hasDefaultStridesForShape(this.input)) {
            // NOTE(review): the duplicated input is not read again in this method; kept to
            // preserve the existing side effect on `input`.
            this.input = workspaceMgr.dup(ArrayType.ACTIVATIONS, this.input, 'f');
        }
        // The same index may occur several times, so gradients are accumulated with ADD.
        INDArray indices = Nd4j.createFromArray(this.indexes);
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weightGradients, indices, delta, WEIGHT_DIM);

        Gradient ret = new DefaultGradient();
        ret.gradientForVariable().put(WEIGHT_KEY, weightGradients);
        if (this.hasBias()) {
            INDArray biasGradientsView = this.gradientViews.get(BIAS_KEY);
            delta.sum(biasGradientsView, new int[]{0});
            ret.gradientForVariable().put(BIAS_KEY, biasGradientsView);
        }
        return new Pair<>(ret, null);
    }

    /**
     * Looks up the embedding rows for the current input (before the activation function).
     *
     * <p>Side effects: stores the flattened indices in {@link #indexes}; may replace rank-1
     * {@code input} with its {@code [N, 1, 1]} reshape; and, when the configuration has
     * {@code inferInputLength} set, overwrites the configured input length with the observed
     * sequence length.
     *
     * @throws IllegalStateException     if the input rank/shape is unsupported
     * @throws DL4JInvalidInputException if the sequence length does not match the configured
     *                                   input length, or an index is outside {@code [0, nIn)}
     */
    @Override
    protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer layerConf = embeddingConf();
        INDArray in = toIndexMatrix();
        if (layerConf.isInferInputLength()) {
            layerConf.setInputLength(in.columns());
        }
        int inputLength = layerConf.getInputLength();
        if (in.columns() != inputLength) {
            throw new DL4JInvalidInputException("Sequence length of embedding input has to be equal to the specified input length: " + inputLength + " i.e. we expect input shape [numExamples, inputLength] (or [numExamples, 1, inputLength]) with each entry being an integer index, got " + Arrays.toString(this.input.shape()) + " instead, for layer with id: " + this.layerId());
        }
        long nIn = layerConf.getNIn();
        int minibatch = in.rows();
        // Need a dense, c-ordered buffer so data().asInt() yields exactly the row-major indices.
        if (in.ordering() != 'c' || in.isView() || !Shape.hasDefaultStridesForShape(in)) {
            in = workspaceMgr.dup(ArrayType.INPUT, in, 'c');
        }
        this.indexes = in.data().asInt();
        validateIndexes(this.indexes, nIn);

        INDArray weights = this.getParam(WEIGHT_KEY);
        long nOut = layerConf.getNOut();
        // Widen before multiplying (consistent with backpropGradient) to avoid int overflow.
        INDArray destination = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[]{(long) minibatch * inputLength, nOut}, 'c');
        INDArray rows = Nd4j.pullRows(weights, destination, 1, this.indexes);
        if (this.hasBias()) {
            rows.addiRowVector(this.getParam(BIAS_KEY));
        }
        INDArray ret = rows.reshape('c', new long[]{minibatch, inputLength, nOut});
        if (isNcw()) {
            ret = ret.permute(new int[]{0, 2, 1});
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
    }

    /**
     * Forward pass: embedding lookup, activation function, then masking (in place) when a mask
     * array is set.
     *
     * @throws IllegalStateException if the mask shape does not match the input
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray preOut = this.preOutput(training, workspaceMgr);
        INDArray ret = embeddingConf().getActivationFn().getActivation(preOut, training);
        if (this.maskArray != null) {
            validateMaskShape();
            Broadcast.mul(ret, this.maskArray.castTo(ret.dataType()), ret, maskBroadcastDims());
        }
        return ret;
    }

    @Override
    public boolean hasBias() {
        return embeddingConf().hasBias();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Dropout is not supported for this layer; always throws. */
    @Override
    protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Dropout not supported with EmbeddingLayer " + this.layerId());
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /** Clears layer state, including the cached indices from the last forward pass. */
    @Override
    public void clear() {
        super.clear();
        this.indexes = null;
    }

    /** Returns the layer configuration with its concrete type. */
    private org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer embeddingConf() {
        return (org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer) this.layerConf();
    }

    /** True when the configured output format is NCW ({@code [minibatch, nOut, seqLength]}). */
    private boolean isNcw() {
        return embeddingConf().getOutputFormat() == RNNFormat.NCW;
    }

    /** Output dimensions that the rank-2 mask {@code [minibatch, seqLength]} broadcasts along. */
    private int[] maskBroadcastDims() {
        return isNcw() ? new int[]{0, 2} : new int[]{0, 1};
    }

    /**
     * Returns {@link #input} as a rank-2 {@code [minibatch, seqLength]} index matrix.
     * Rank-1 input is first reshaped to {@code [length, 1, 1]} and stored back into {@code input}.
     *
     * @throws IllegalStateException if the input is not rank 2, or rank 3 with size 1 on dim 1
     */
    private INDArray toIndexMatrix() {
        if (this.input.rank() == 1) {
            this.input = this.input.reshape(new long[]{this.input.length(), 1L, 1L});
        }
        int rank = this.input.rank();
        if (rank == 3 && this.input.size(1) != 1L || rank != 2 && rank != 3) {
            throw new IllegalStateException("Invalid input: EmbeddingSequenceLayer expects either rank 2 input of shape [minibatch,seqLength] or rank 3 input of shape [minibatch,1,seqLength]. Got rank " + this.input.rank() + " input of shape " + Arrays.toString(this.input.shape()));
        }
        if (rank == 3) {
            return this.input.reshape(this.input.ordering(), new long[]{this.input.size(0), this.input.size(2)});
        }
        return this.input;
    }

    /**
     * Checks that every index lies in {@code [0, nIn)}.
     *
     * @throws DL4JInvalidInputException on the first out-of-range index
     */
    private static void validateIndexes(int[] indexes, long nIn) {
        for (int i = 0; i < indexes.length; ++i) {
            if (indexes[i] < 0 || indexes[i] >= nIn) {
                throw new DL4JInvalidInputException("Invalid index for embedding layer: got index " + indexes[i] + " for entry " + i + " in minibatch; indexes must be between 0 and nIn-1 inclusive (0 to " + (nIn - 1L) + ")");
            }
        }
    }

    /**
     * Checks that the mask is rank 2 and matches the input: equal shape for rank-2 input, or
     * input dims 0 and 2 for rank-3 input.
     *
     * @throws IllegalStateException if the mask shape is incompatible
     */
    private void validateMaskShape() {
        // Single short-circuiting condition: size(1) must not be queried on a non-rank-2 mask.
        if (this.maskArray.rank() != 2
                || this.input.rank() == 2 && !this.maskArray.equalShapes(this.input)
                || this.input.rank() == 3 && (this.input.size(0) != this.maskArray.size(0) || this.input.size(2) != this.maskArray.size(1))) {
            throw new IllegalStateException("Mask array for EmbeddingSequenceLayer (when defined) must be rank 2 and have shape equal to input shape (when input is rank 2, shape [mb,tsLength]) or equal to input dimensions 0 and 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(this.input.shape()) + ", mask shape: " + Arrays.toString(this.maskArray.shape()));
        }
    }
}

