/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasELU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasLeakyReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasPReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasSoftmax;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasThresholdedReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution3D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDeconvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDepthwiseConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSeparableConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasActivation;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasFlatten;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasLambda;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasMasking;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasMerge;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasPermute;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasRepeatVector;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasReshape;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasSpatialDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
import org.deeplearning4j.nn.modelimport.keras.layers.normalization.KerasBatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasGlobalPooling;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling3D;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for reading values out of Keras layer configurations (JSON objects parsed into
 * nested {@code Map}/{@code List} structures) and for instantiating the matching DL4J
 * {@link KerasLayer} wrapper.
 *
 * <p>Field names are never hard-coded here; they are always looked up through the supplied
 * {@link KerasLayerConfiguration}, so the same code serves both Keras 1 and Keras 2 configs.
 */
public class KerasLayerUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasLayerUtils.class);

    /**
     * Checks a layer config for regularization settings that cannot be imported.
     *
     * @param layerConfig           outer Keras layer config (must contain the inner {@code config} field)
     * @param enforceTrainingConfig if true, unknown regularizer fields are an error; otherwise they are logged
     * @param conf                  Keras field-name configuration
     * @throws UnsupportedKerasConfigurationException if a bias L1/L2 regularizer is configured, or an unknown
     *                                                regularizer field is found while enforcing training config
     * @throws InvalidKerasConfigurationException     if the inner {@code config} field is missing
     */
    public static void checkForUnsupportedConfigurations(Map<String, Object> layerConfig, boolean enforceTrainingConfig, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        // Called only for their side effect: both throw if a bias L1/L2 regularizer is present.
        getBiasL1RegularizationFromConfig(layerConfig, enforceTrainingConfig, conf);
        getBiasL2RegularizationFromConfig(layerConfig, enforceTrainingConfig, conf);
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_W_REGULARIZER())) {
            checkForUnknownRegularizer(asMap(innerConfig.get(conf.getLAYER_FIELD_W_REGULARIZER())), enforceTrainingConfig, conf);
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_B_REGULARIZER())) {
            checkForUnknownRegularizer(asMap(innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER())), enforceTrainingConfig, conf);
        }
    }

    /**
     * Rejects L1 regularization on the bias parameter, which is not supported on import.
     *
     * @param layerConfig           outer Keras layer config
     * @param enforceTrainingConfig currently unused
     * @param conf                  Keras field-name configuration
     * @return always 0.0 when no bias L1 regularizer is configured
     * @throws UnsupportedKerasConfigurationException if the bias regularizer has an L1 entry
     * @throws InvalidKerasConfigurationException     if the inner {@code config} field is missing
     */
    public static double getBiasL1RegularizationFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (hasBiasRegularizer(layerConfig, conf, conf.getREGULARIZATION_TYPE_L1())) {
            throw new UnsupportedKerasConfigurationException("L1 regularization for bias parameter not supported");
        }
        return 0.0;
    }

    /**
     * Rejects L2 regularization on the bias parameter, which is not supported on import.
     *
     * @return always 0.0 when no bias L2 regularizer is configured
     * @throws UnsupportedKerasConfigurationException if the bias regularizer has an L2 entry
     * @throws InvalidKerasConfigurationException     if the inner {@code config} field is missing
     */
    private static double getBiasL2RegularizationFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (hasBiasRegularizer(layerConfig, conf, conf.getREGULARIZATION_TYPE_L2())) {
            throw new UnsupportedKerasConfigurationException("L2 regularization for bias parameter not supported");
        }
        return 0.0;
    }

    /** Returns true if the layer's bias regularizer map is present and contains the given regularization key. */
    private static boolean hasBiasRegularizer(Map<String, Object> layerConfig, KerasLayerConfiguration conf, String regularizationType) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        Map<String, Object> regularizerConfig = asMap(innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER()));
        return regularizerConfig != null && regularizerConfig.containsKey(regularizationType);
    }

    /**
     * Flags any regularizer field other than L1, L2, name, class name or config.
     *
     * @param regularizerConfig     regularizer map; {@code null} is accepted and ignored
     * @param enforceTrainingConfig if true an unknown field throws, otherwise it is logged and skipped
     * @param conf                  Keras field-name configuration
     * @throws UnsupportedKerasConfigurationException on an unknown field while enforcing training config
     */
    private static void checkForUnknownRegularizer(Map<String, Object> regularizerConfig, boolean enforceTrainingConfig, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException {
        if (regularizerConfig == null) {
            return;
        }
        for (String field : regularizerConfig.keySet()) {
            boolean known = field.equals(conf.getREGULARIZATION_TYPE_L1())
                    || field.equals(conf.getREGULARIZATION_TYPE_L2())
                    || field.equals(conf.getLAYER_FIELD_NAME())
                    || field.equals(conf.getLAYER_FIELD_CLASS_NAME())
                    || field.equals(conf.getLAYER_FIELD_CONFIG());
            if (known) {
                continue;
            }
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException("Unknown regularization field " + field);
            }
            log.warn("Ignoring unknown regularization field {}", field);
        }
    }

    /**
     * Builds the DL4J wrapper for a Keras layer config without enforcing training-only settings.
     *
     * @see #getKerasLayerFromConfig(Map, boolean, KerasLayerConfiguration, Map, Map, Map)
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf, Map<String, Class<? extends KerasLayer>> customLayers, Map<String, SameDiffLambdaLayer> lambdaLayers, Map<String, ? extends KerasLayer> previousLayers) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getKerasLayerFromConfig(layerConfig, false, conf, customLayers, lambdaLayers, previousLayers);
    }

    /**
     * Builds the DL4J wrapper for a Keras layer config, dispatching on its class name.
     *
     * <p>{@code TimeDistributed} wrappers are unwrapped first. The unwrapping mutates
     * {@code layerConfig}; see {@link #getTimeDistributedLayerConfig}. Layer types that are not
     * built in are looked up in {@code customLayers} and instantiated reflectively through a
     * public {@code (Map)} constructor.
     *
     * @param layerConfig           outer Keras layer config
     * @param enforceTrainingConfig whether unsupported training-only settings are errors
     * @param conf                  Keras field-name configuration
     * @param customLayers          user-registered layer classes keyed by Keras class name
     * @param lambdaLayers          user-registered SameDiff lambda layers keyed by Keras layer name
     * @param previousLayers        already-imported layers, passed to layers that need them
     * @return the imported layer
     * @throws UnsupportedKerasConfigurationException if the layer type is unknown, or a Lambda layer has no registration
     * @throws InvalidKerasConfigurationException     if required config fields are missing
     * @throws RuntimeException                       if a custom layer class cannot be instantiated
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig, KerasLayerConfiguration conf, Map<String, Class<? extends KerasLayer>> customLayers, Map<String, SameDiffLambdaLayer> lambdaLayers, Map<String, ? extends KerasLayer> previousLayers) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String layerClassName = getClassNameFromConfig(layerConfig, conf);
        if (layerClassName.equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED())) {
            layerConfig = getTimeDistributedLayerConfig(layerConfig, conf);
            layerClassName = getClassNameFromConfig(layerConfig, conf);
        }
        KerasLayer layer = null;
        if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ACTIVATION())) {
            layer = new KerasActivation(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LEAKY_RELU())) {
            layer = new KerasLeakyReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MASKING())) {
            layer = new KerasMasking(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_THRESHOLDED_RELU())) {
            layer = new KerasThresholdedReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_PRELU())) {
            layer = new KerasPReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DROPOUT())) {
            layer = new KerasDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_1D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_2D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_3D())) {
            layer = new KerasSpatialDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ALPHA_DROPOUT())) {
            layer = new KerasAlphaDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_DROPOUT())) {
            layer = new KerasGaussianDropout(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_NOISE())) {
            layer = new KerasGaussianNoise(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DENSE())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE())) {
            layer = new KerasDense(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_BIDIRECTIONAL())) {
            layer = new KerasBidirectional(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LSTM())) {
            layer = new KerasLSTM(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SIMPLE_RNN())) {
            layer = new KerasSimpleRnn(layerConfig, enforceTrainingConfig, previousLayers);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_3D())) {
            layer = new KerasConvolution3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_2D())) {
            layer = new KerasConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DECONVOLUTION_2D())) {
            layer = new KerasDeconvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_1D())) {
            layer = new KerasConvolution1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_2D())) {
            layer = new KerasAtrousConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_1D())) {
            layer = new KerasAtrousConvolution1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_DEPTHWISE_CONVOLUTION_2D())) {
            layer = new KerasDepthwiseConvolution2D(layerConfig, previousLayers, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SEPARABLE_CONVOLUTION_2D())) {
            layer = new KerasSeparableConvolution2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_3D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_3D())) {
            layer = new KerasPooling3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_2D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_2D())) {
            layer = new KerasPooling2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MAX_POOLING_1D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_1D())) {
            layer = new KerasPooling1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_3D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_3D())) {
            layer = new KerasGlobalPooling(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_BATCHNORMALIZATION())) {
            layer = new KerasBatchNormalization(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_EMBEDDING())) {
            layer = new KerasEmbedding(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_INPUT())) {
            layer = new KerasInput(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_REPEAT())) {
            layer = new KerasRepeatVector(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_PERMUTE())) {
            layer = new KerasPermute(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MERGE())) {
            layer = new KerasMerge(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ADD())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_ADD())) {
            // Previously compared against LAYER_CLASS_NAME_ADD twice, so the functional "add" name never matched.
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Add, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SUBTRACT())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_SUBTRACT())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Subtract, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_AVERAGE())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_AVERAGE())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Average, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_MULTIPLY())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_MULTIPLY())) {
            layer = new KerasMerge(layerConfig, ElementWiseVertex.Op.Product, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CONCATENATE())
                || layerClassName.equals(conf.getLAYER_CLASS_NAME_FUNCTIONAL_CONCATENATE())) {
            // A null op tells KerasMerge to concatenate rather than combine element-wise.
            layer = new KerasMerge(layerConfig, null, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_FLATTEN())) {
            layer = new KerasFlatten(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_RESHAPE())) {
            layer = new KerasReshape(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_1D())) {
            layer = new KerasZeroPadding1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_2D())) {
            layer = new KerasZeroPadding2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_3D())) {
            layer = new KerasZeroPadding3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_1D())) {
            layer = new KerasUpsampling1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_2D())) {
            layer = new KerasUpsampling2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_3D())) {
            layer = new KerasCropping3D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_2D())) {
            layer = new KerasCropping2D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_CROPPING_1D())) {
            layer = new KerasCropping1D(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LAMBDA())) {
            String lambdaLayerName = getLayerNameFromConfig(layerConfig, conf);
            if (!lambdaLayers.containsKey(lambdaLayerName) && !customLayers.containsKey(layerClassName)) {
                throw new UnsupportedKerasConfigurationException("No SameDiff Lambda layer found for Lambda layer " + lambdaLayerName + ". You can register a SameDiff Lambda layer using KerasLayer.registerLambdaLayer(lambdaLayerName, sameDiffLambdaLayer);");
            }
            SameDiffLambdaLayer lambdaLayer = lambdaLayers.get(lambdaLayerName);
            if (lambdaLayer != null) {
                layer = new KerasLambda(layerConfig, enforceTrainingConfig, (SameDiffLayer) lambdaLayer);
            }
            // Otherwise a custom layer registered under the "Lambda" class name is used (see below).
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_RELU())) {
            layer = new KerasReLU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_ELU())) {
            layer = new KerasELU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())) {
            layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())) {
            layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
        } else if (conf instanceof Keras2LayerConfiguration
                && layerClassName.equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())) {
            layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
        }
        if (layer == null) {
            Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
            if (customConfig == null) {
                throw new UnsupportedKerasConfigurationException("Unsupported keras layer type " + layerClassName);
            }
            try {
                Constructor<? extends KerasLayer> constructor = customConfig.getConstructor(Map.class);
                layer = constructor.newInstance(layerConfig);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException("Could not instantiate custom layer " + customConfig.getName()
                        + " for Keras layer type " + layerClassName
                        + " (a public constructor taking a single Map argument is required)", e);
            }
        }
        return layer;
    }

    /**
     * Returns the Keras class name (e.g. the layer type) stored in the outer layer config.
     *
     * @throws InvalidKerasConfigurationException if the class-name field is missing
     */
    public static String getClassNameFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        requireField(layerConfig, conf.getLAYER_FIELD_CLASS_NAME());
        return (String) layerConfig.get(conf.getLAYER_FIELD_CLASS_NAME());
    }

    /**
     * Unwraps a {@code TimeDistributed} layer config in place so that it describes the wrapped layer.
     *
     * <p>The wrapped layer's class name replaces the wrapper's. The wrapped layer's inner config is
     * merged into the wrapper's inner config, keeping the wrapper's layer name, and the nested
     * layer field is removed. Note that {@code layerConfig}, and the wrapped layer's inner config,
     * are mutated.
     *
     * @return the same {@code layerConfig} instance, after modification
     * @throws InvalidKerasConfigurationException if the config is not a TimeDistributed layer or required fields are missing
     */
    public static Map<String, Object> getTimeDistributedLayerConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        requireField(layerConfig, conf.getLAYER_FIELD_CLASS_NAME());
        Object className = layerConfig.get(conf.getLAYER_FIELD_CLASS_NAME());
        if (!conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED().equals(className)) {
            throw new InvalidKerasConfigurationException("Expected " + conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED() + " layer, found " + className);
        }
        requireField(layerConfig, conf.getLAYER_FIELD_CONFIG());
        Map<String, Object> outerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        Map<String, Object> innerLayer = asMap(outerConfig.get(conf.getLAYER_FIELD_LAYER()));
        if (innerLayer == null) {
            throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_LAYER() + " missing from " + conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED() + " layer config");
        }
        layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), innerLayer.get(conf.getLAYER_FIELD_CLASS_NAME()));
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(innerLayer, conf);
        // Keep the wrapper's name so references from other layers still resolve.
        innerConfig.put(conf.getLAYER_FIELD_NAME(), outerConfig.get(conf.getLAYER_FIELD_NAME()));
        outerConfig.putAll(innerConfig);
        outerConfig.remove(conf.getLAYER_FIELD_LAYER());
        return layerConfig;
    }

    /**
     * Returns the nested {@code config} map of a layer config (the map is returned as-is, not copied).
     *
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static Map<String, Object> getInnerLayerConfigFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        requireField(layerConfig, conf.getLAYER_FIELD_CONFIG());
        return asMap(layerConfig.get(conf.getLAYER_FIELD_CONFIG()));
    }

    /**
     * Returns the layer's name.
     *
     * <p>For Keras 2 TensorFlow-op layers the name is read from the outer config; for all other
     * layers it is read from the inner {@code config} map.
     *
     * @throws InvalidKerasConfigurationException if the name (or class-name/config) field is missing
     */
    public static String getLayerNameFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        if (conf instanceof Keras2LayerConfiguration
                && getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())) {
            requireField(layerConfig, conf.getLAYER_FIELD_NAME());
            return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
        }
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        requireField(innerConfig, conf.getLAYER_FIELD_NAME());
        return (String) innerConfig.get(conf.getLAYER_FIELD_NAME());
    }

    /**
     * Returns the layer's input shape, i.e. the batch input shape without its first (batch) entry.
     * Unknown ({@code null}) dimensions are reported as 0.
     *
     * @return the input shape, or {@code null} if no batch input shape is configured
     * @throws InvalidKerasConfigurationException if the config field is missing or the batch input shape is empty
     */
    public static int[] getInputShapeFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        List<?> batchInputShape = (List<?>) innerConfig.get(conf.getLAYER_FIELD_BATCH_INPUT_SHAPE());
        if (batchInputShape == null) {
            return null;
        }
        if (batchInputShape.isEmpty()) {
            throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_BATCH_INPUT_SHAPE() + " is empty; expected at least a batch dimension");
        }
        int[] inputShape = new int[batchInputShape.size() - 1];
        for (int i = 1; i < batchInputShape.size(); i++) {
            Object dim = batchInputShape.get(i);
            // Read through Number so values parsed as Long (or other integral types) are accepted too.
            inputShape[i - 1] = dim == null ? 0 : ((Number) dim).intValue();
        }
        return inputShape;
    }

    /**
     * Determines the dimension ordering for a layer.
     *
     * <p>The outer {@code backend} field is consulted first ("tensorflow"/"cntk" map to TENSORFLOW,
     * "theano" to THEANO). A dim-ordering field in the inner config, when present and recognized,
     * overrides it. The result is NONE if neither field yields a recognized value.
     *
     * @throws InvalidKerasConfigurationException if the inner config field is missing
     */
    public static KerasLayer.DimOrder getDimOrderFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        KerasLayer.DimOrder dimOrder = KerasLayer.DimOrder.NONE;
        if (layerConfig.containsKey(conf.getLAYER_FIELD_BACKEND())) {
            String backend = (String) layerConfig.get(conf.getLAYER_FIELD_BACKEND());
            if ("tensorflow".equals(backend) || "cntk".equals(backend)) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if ("theano".equals(backend)) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            }
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DIM_ORDERING())) {
            String dimOrderStr = (String) innerConfig.get(conf.getLAYER_FIELD_DIM_ORDERING());
            if (conf.getDIM_ORDERING_TENSORFLOW().equals(dimOrderStr)) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if (conf.getDIM_ORDERING_THEANO().equals(dimOrderStr)) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            } else {
                // Report the unrecognized string itself (previously the backend-derived enum was logged).
                log.warn("Keras layer has unknown Keras dimension order: {}", dimOrderStr);
            }
        }
        return dimOrder;
    }

    /**
     * Returns the names of the layers feeding this layer.
     *
     * <p>Only the first entry of the inbound-nodes list is read. Each element of that entry is
     * expected to be a list whose first item is the inbound layer's name.
     *
     * @return a new mutable list; empty if the inbound-nodes field is missing, null or empty
     */
    public static List<String> getInboundLayerNamesFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) {
        List<String> inboundLayerNames = new ArrayList<>();
        List<?> inboundNodes = (List<?>) layerConfig.get(conf.getLAYER_FIELD_INBOUND_NODES());
        if (inboundNodes == null || inboundNodes.isEmpty()) {
            return inboundLayerNames;
        }
        for (Object node : (List<?>) inboundNodes.get(0)) {
            inboundLayerNames.add((String) ((List<?>) node).get(0));
        }
        return inboundLayerNames;
    }

    /**
     * Returns the number of outputs of a layer.
     *
     * <p>The value is taken from the first of these inner-config fields that is present:
     * output dim, embedding output dim, number of filters.
     *
     * @throws InvalidKerasConfigurationException if none of those fields is present
     */
    public static int getNOutFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_OUTPUT_DIM())) {
            return ((Number) innerConfig.get(conf.getLAYER_FIELD_OUTPUT_DIM())).intValue();
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM())) {
            return ((Number) innerConfig.get(conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM())).intValue();
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_FILTER())) {
            return ((Number) innerConfig.get(conf.getLAYER_FIELD_NB_FILTER())).intValue();
        }
        throw new InvalidKerasConfigurationException("Could not determine number of outputs for layer: no "
                + conf.getLAYER_FIELD_OUTPUT_DIM() + ", " + conf.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM()
                + " or " + conf.getLAYER_FIELD_NB_FILTER() + " field found");
    }

    /**
     * Returns the layer's input dimension, when it is configured as a number.
     *
     * @return the input dim, or {@code null} if the field is absent or not numeric
     * @throws InvalidKerasConfigurationException if the inner config field is missing
     */
    public static Integer getNInFromInputDim(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        Object inputDim = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM());
        return inputDim instanceof Number ? Integer.valueOf(((Number) inputDim).intValue()) : null;
    }

    /**
     * Returns {@code 1.0 - p}, where {@code p} is the Keras dropout rate.
     *
     * <p>The rate is read from the dropout field if present, otherwise from the dropout-W field.
     * The result is 1.0 when neither field is present.
     *
     * @throws InvalidKerasConfigurationException if the inner config field is missing
     */
    public static double getDropoutFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        // Keras may serialize the rate as an integer (e.g. 0) or a floating-point value; Number handles both.
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DROPOUT())) {
            return 1.0 - ((Number) innerConfig.get(conf.getLAYER_FIELD_DROPOUT())).doubleValue();
        }
        if (innerConfig.containsKey(conf.getLAYER_FIELD_DROPOUT_W())) {
            return 1.0 - ((Number) innerConfig.get(conf.getLAYER_FIELD_DROPOUT_W())).doubleValue();
        }
        return 1.0;
    }

    /**
     * Returns whether the layer uses a bias; defaults to {@code true} when the use-bias field is absent.
     *
     * @throws InvalidKerasConfigurationException if the inner config field is missing
     */
    public static boolean getHasBiasFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_USE_BIAS())) {
            return (Boolean) innerConfig.get(conf.getLAYER_FIELD_USE_BIAS());
        }
        return true;
    }

    /**
     * Returns whether zero masking is enabled; defaults to {@code true} when the mask-zero field is absent.
     *
     * <p>NOTE(review): Keras' own default for {@code mask_zero} is False. Confirm whether defaulting
     * to true here is intended.
     *
     * @throws InvalidKerasConfigurationException if the inner config field is missing
     */
    public static boolean getZeroMaskingFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        if (innerConfig.containsKey(conf.getLAYER_FIELD_MASK_ZERO())) {
            return (Boolean) innerConfig.get(conf.getLAYER_FIELD_MASK_ZERO());
        }
        return true;
    }

    /**
     * Returns the configured masking value.
     *
     * <p>Integer and floating-point values are both accepted. A non-numeric value is logged
     * and replaced by 0.0.
     *
     * @throws InvalidKerasConfigurationException if the mask-value (or inner config) field is missing
     */
    public static double getMaskingValueFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(conf.getLAYER_FIELD_MASK_VALUE())) {
            throw new InvalidKerasConfigurationException("No mask value found, field " + conf.getLAYER_FIELD_MASK_VALUE());
        }
        Object maskValue = innerConfig.get(conf.getLAYER_FIELD_MASK_VALUE());
        if (maskValue instanceof Number) {
            // Previously only Double was accepted, so integer mask values were silently replaced by 0.0.
            return ((Number) maskValue).doubleValue();
        }
        log.warn("Couldn't read masking value, default to 0.0");
        return 0.0;
    }

    /**
     * Handles a weights map that holds more than the default weight and bias entries.
     *
     * <p>When the map has more than two entries, the default weight (W) and bias (b) entries are
     * removed from {@code weights} itself. A warning then lists the remaining, unrecognized
     * parameter names. Maps with two or fewer entries are left untouched.
     *
     * @param weights parameter name to weights map; mutated as described above
     * @param conf    Keras field-name configuration
     */
    public static void removeDefaultWeights(Map<String, INDArray> weights, KerasLayerConfiguration conf) {
        if (weights.size() > 2) {
            weights.remove(conf.getKERAS_PARAM_NAME_W());
            weights.remove(conf.getKERAS_PARAM_NAME_B());
            log.warn("Attemping to set weights for unknown parameters: {}", String.join(", ", weights.keySet()));
        }
    }

    /**
     * Determines whether any inbound layer introduces masking, and with which value.
     *
     * <p>An inbound {@link KerasEmbedding} with zero masking enables masking but leaves the value
     * unchanged. An inbound {@link KerasMasking} enables masking and sets the value; the last such
     * layer in iteration order wins. Inbound names not found in {@code previousLayers} are skipped.
     *
     * @return a pair of (masking enabled, masking value); the value defaults to 0.0
     */
    public static Pair<Boolean, Double> getMaskingConfiguration(List<String> inboundLayerNames, Map<String, ? extends KerasLayer> previousLayers) {
        boolean hasMasking = false;
        double maskingValue = 0.0;
        for (String inboundLayerName : inboundLayerNames) {
            KerasLayer inbound = previousLayers.get(inboundLayerName);
            if (inbound instanceof KerasEmbedding && ((KerasEmbedding) inbound).isZeroMasking()) {
                hasMasking = true;
            } else if (inbound instanceof KerasMasking) {
                hasMasking = true;
                maskingValue = ((KerasMasking) inbound).getMaskingValue();
            }
        }
        return new Pair<>(hasMasking, maskingValue);
    }

    /** Throws the standard "missing field" error if {@code field} is not a key of {@code config}. */
    private static void requireField(Map<String, Object> config, String field) throws InvalidKerasConfigurationException {
        if (!config.containsKey(field)) {
            throw new InvalidKerasConfigurationException("Field " + field + " missing from layer config");
        }
    }

    /** Views a parsed-JSON object as a string-keyed map; throws ClassCastException if it is not a Map. */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> asMap(Object value) {
        return (Map<String, Object>) value;
    }
}

