/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers.wrappers;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

public class KerasBidirectional
extends KerasLayer {
    private KerasLayer kerasRnnlayer;

    public KerasBidirectional(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        super(kerasVersion);
    }

    public KerasBidirectional(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true, Collections.emptyMap());
    }

    public KerasBidirectional(Map<String, Object> layerConfig, Map<String, ? extends KerasLayer> previousLayers) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true, previousLayers);
    }

    public KerasBidirectional(Map<String, Object> layerConfig, boolean enforceTrainingConfig, Map<String, ? extends KerasLayer> previousLayers) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        String rnnClass;
        Bidirectional.Mode mode;
        String mergeModeString;
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, this.conf);
        if (!innerConfig.containsKey("merge_mode")) {
            throw new InvalidKerasConfigurationException("Field 'merge_mode' not found in configuration of Bidirectional layer.");
        }
        if (!innerConfig.containsKey("layer")) {
            throw new InvalidKerasConfigurationException("Field 'layer' not found in configuration ofBidirectional layer, i.e. no layer to be wrapped found.");
        }
        Map innerRnnConfig = (Map)innerConfig.get("layer");
        if (!innerRnnConfig.containsKey("class_name")) {
            throw new InvalidKerasConfigurationException("No 'class_name' specified within Bidirectional layerconfiguration.");
        }
        switch (mergeModeString = (String)innerConfig.get("merge_mode")) {
            case "sum": {
                mode = Bidirectional.Mode.ADD;
                break;
            }
            case "concat": {
                mode = Bidirectional.Mode.CONCAT;
                break;
            }
            case "mul": {
                mode = Bidirectional.Mode.MUL;
                break;
            }
            case "ave": {
                mode = Bidirectional.Mode.AVERAGE;
                break;
            }
            default: {
                throw new UnsupportedKerasConfigurationException("Merge mode " + mergeModeString + " not supported.");
            }
        }
        innerRnnConfig.put(this.conf.getLAYER_FIELD_KERAS_VERSION(), this.kerasMajorVersion);
        switch (rnnClass = (String)innerRnnConfig.get("class_name")) {
            case "LSTM": {
                this.kerasRnnlayer = new KerasLSTM(innerRnnConfig, enforceTrainingConfig, previousLayers);
                try {
                    LSTM rnnLayer = (LSTM)((KerasLSTM)this.kerasRnnlayer).getLSTMLayer();
                    this.layer = new Bidirectional(mode, (Layer)rnnLayer);
                    this.layer.setLayerName(this.layerName);
                }
                catch (Exception e) {
                    LastTimeStep rnnLayer = (LastTimeStep)((KerasLSTM)this.kerasRnnlayer).getLSTMLayer();
                    this.layer = new Bidirectional(mode, (Layer)rnnLayer);
                    this.layer.setLayerName(this.layerName);
                }
                break;
            }
            case "SimpleRNN": {
                this.kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
                Layer rnnLayer = ((KerasSimpleRnn)this.kerasRnnlayer).getSimpleRnnLayer();
                this.layer = new Bidirectional(mode, rnnLayer);
                this.layer.setLayerName(this.layerName);
                break;
            }
            default: {
                throw new UnsupportedKerasConfigurationException("Currently only two types of recurrent Keras layers aresupported, 'LSTM' and 'SimpleRNN'. You tried to load a layer of class:" + rnnClass);
            }
        }
    }

    public Layer getUnderlyingRecurrentLayer() {
        return this.kerasRnnlayer.getLayer();
    }

    public Bidirectional getBidirectionalLayer() {
        return (Bidirectional)this.layer;
    }

    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
        }
        InputPreProcessor preProcessor = this.getInputPreprocessor(inputType);
        if (preProcessor != null) {
            return this.getBidirectionalLayer().getOutputType(-1, preProcessor.getOutputType(inputType[0]));
        }
        return this.getBidirectionalLayer().getOutputType(-1, inputType[0]);
    }

    @Override
    public int getNumParams() {
        return 2 * this.kerasRnnlayer.getNumParams();
    }

    @Override
    public InputPreProcessor getInputPreprocessor(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
        }
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers((InputType)inputType[0], (RNNFormat)((Bidirectional)this.layer).getRNNDataFormat(), (String)this.layerName);
    }

    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        Map<String, INDArray> forwardWeights = this.getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward");
        Map<String, INDArray> backwardWeights = this.getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward");
        this.weights = new HashMap();
        for (String key : forwardWeights.keySet()) {
            this.weights.put("f" + key, forwardWeights.get(key));
        }
        for (String key : backwardWeights.keySet()) {
            this.weights.put("b" + key, backwardWeights.get(key));
        }
    }

    private Map<String, INDArray> getUnderlyingWeights(Layer l, Map<String, INDArray> weights, String direction) throws InvalidKerasConfigurationException {
        int keras1SubstringLength;
        if (this.kerasRnnlayer instanceof KerasLSTM) {
            keras1SubstringLength = 3;
        } else if (this.kerasRnnlayer instanceof KerasSimpleRnn) {
            keras1SubstringLength = 1;
        } else {
            throw new InvalidKerasConfigurationException("Unsupported layer type " + this.kerasRnnlayer.getClassName());
        }
        HashMap<String, INDArray> newWeights = new HashMap<String, INDArray>();
        for (String key : weights.keySet()) {
            Object newKey;
            if (!key.contains(direction)) continue;
            if (this.kerasMajorVersion == 2) {
                String[] subKeys = key.split("_");
                newKey = key.contains("recurrent") ? subKeys[subKeys.length - 2] + "_" + subKeys[subKeys.length - 1] : subKeys[subKeys.length - 1];
            } else {
                newKey = key.substring(key.length() - keras1SubstringLength);
            }
            newWeights.put((String)newKey, weights.get(key));
        }
        if (!newWeights.isEmpty()) {
            weights = newWeights;
        }
        Layer layerBefore = this.kerasRnnlayer.getLayer();
        this.kerasRnnlayer.setLayer(l);
        this.kerasRnnlayer.setWeights(weights);
        Map<String, INDArray> ret = this.kerasRnnlayer.getWeights();
        this.kerasRnnlayer.setLayer(layerBefore);
        return ret;
    }
}

