/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.Map;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;

/**
 * Utility methods for translating Keras activation function names into DL4J {@link Activation}
 * values and {@link IActivation} instances.
 *
 * <p>The Keras activation names are not hard-coded here: every comparison is made against a value
 * obtained from the supplied {@link KerasLayerConfiguration}.
 */
public class KerasActivationUtils {

    /**
     * Maps a Keras activation function name to the corresponding DL4J {@link Activation}.
     *
     * <p>The name is compared (case-sensitively, via {@link String#equals}) against the values of
     * {@code conf.getKERAS_ACTIVATION_*()} for SOFTMAX, SOFTPLUS, SOFTSIGN, RELU, RELU6, ELU, SELU,
     * TANH, SIGMOID, HARD_SIGMOID, LINEAR and SWISH. LINEAR maps to {@link Activation#IDENTITY} and
     * HARD_SIGMOID to {@link Activation#HARDSIGMOID}; the others map to the same-named constant.
     *
     * @param kerasActivation Keras activation name to translate
     * @param conf            Keras configuration supplying the expected activation names
     * @return the matching DL4J activation
     * @throws UnsupportedKerasConfigurationException if {@code kerasActivation} is {@code null} or
     *         matches none of the supported names
     */
    public static Activation mapToActivation(String kerasActivation, KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException {
        // Without this guard a null name would surface as a NullPointerException from
        // kerasActivation.equals(...) rather than the declared checked exception.
        if (kerasActivation == null) {
            throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + kerasActivation);
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SOFTMAX())) {
            return Activation.SOFTMAX;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SOFTPLUS())) {
            return Activation.SOFTPLUS;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SOFTSIGN())) {
            return Activation.SOFTSIGN;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_RELU())) {
            return Activation.RELU;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_RELU6())) {
            return Activation.RELU6;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_ELU())) {
            return Activation.ELU;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SELU())) {
            return Activation.SELU;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_TANH())) {
            return Activation.TANH;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SIGMOID())) {
            return Activation.SIGMOID;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_HARD_SIGMOID())) {
            return Activation.HARDSIGMOID;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_LINEAR())) {
            // Keras "linear" applies no transformation, which DL4J calls IDENTITY.
            return Activation.IDENTITY;
        }
        if (kerasActivation.equals(conf.getKERAS_ACTIVATION_SWISH())) {
            return Activation.SWISH;
        }
        throw new UnsupportedKerasConfigurationException("Unknown Keras activation function " + kerasActivation);
    }

    /**
     * Maps a Keras activation function name to a DL4J {@link IActivation} instance.
     *
     * @param kerasActivation Keras activation name to translate
     * @param conf            Keras configuration supplying the expected activation names
     * @return the activation function object of the matching {@link Activation}
     * @throws UnsupportedKerasConfigurationException if the name is {@code null} or unsupported
     */
    public static IActivation mapToIActivation(String kerasActivation, KerasLayerConfiguration conf)
            throws UnsupportedKerasConfigurationException {
        Activation activation = KerasActivationUtils.mapToActivation(kerasActivation, conf);
        return activation.getActivationFunction();
    }

    /**
     * Reads the activation from a Keras layer configuration and returns it as an
     * {@link IActivation} instance.
     *
     * @param layerConfig Keras layer configuration map
     * @param conf        Keras configuration supplying field and activation names
     * @return the activation function object for the layer's activation
     * @throws InvalidKerasConfigurationException     if the activation field is missing or null
     *         (or if the inner config cannot be obtained)
     * @throws UnsupportedKerasConfigurationException if the activation value is not a string or
     *         names an unsupported activation
     */
    public static IActivation getIActivationFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return KerasActivationUtils.getActivationFromConfig(layerConfig, conf).getActivationFunction();
    }

    /**
     * Reads the activation from a Keras layer configuration and maps it to a DL4J
     * {@link Activation}.
     *
     * <p>The inner layer config is obtained via
     * {@link KerasLayerUtils#getInnerLayerConfigFromConfig}, and the value stored under
     * {@code conf.getLAYER_FIELD_ACTIVATION()} is passed to
     * {@link #mapToActivation(String, KerasLayerConfiguration)}.
     *
     * @param layerConfig Keras layer configuration map
     * @param conf        Keras configuration supplying field and activation names
     * @return the DL4J activation for the layer
     * @throws InvalidKerasConfigurationException     if the activation field is missing or null
     *         (or if the inner config cannot be obtained)
     * @throws UnsupportedKerasConfigurationException if the activation value is not a string or
     *         names an unsupported activation
     */
    public static Activation getActivationFromConfig(Map<String, Object> layerConfig, KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        String activationField = conf.getLAYER_FIELD_ACTIVATION();
        if (!innerConfig.containsKey(activationField)) {
            throw new InvalidKerasConfigurationException("Keras layer is missing " + activationField + " field");
        }
        Object activationValue = innerConfig.get(activationField);
        if (activationValue == null) {
            throw new InvalidKerasConfigurationException("Keras layer " + activationField + " field is null");
        }
        // The inner config is an untyped Map<String, Object>. A non-string value would otherwise
        // fail with an uninformative ClassCastException, so report it explicitly instead.
        if (!(activationValue instanceof String)) {
            throw new UnsupportedKerasConfigurationException("Unsupported Keras activation configuration in field "
                    + activationField + ": expected an activation name string but found " + activationValue);
        }
        return KerasActivationUtils.mapToActivation((String) activationValue, conf);
    }
}

