package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.class */
/**
 * Utility methods for translating Keras loss-function names into DL4J {@link ILossFunction}s.
 *
 * <p>Built-in Keras / tf.keras loss names are compared against the identifiers exposed by a
 * {@link KerasLayerConfiguration}. Any other name is looked up in a static registry of
 * user-supplied losses populated via {@link #registerCustomLoss(String, ILossFunction)}.
 *
 * <p>NOTE(review): the registry is a plain static {@link HashMap} with no synchronization;
 * presumably callers register custom losses before importing models — confirm if concurrent
 * registration is expected.
 */
public class KerasLossUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasLossUtils.class);
    static final Map<String, ILossFunction> customLoss = new HashMap<>();

    /**
     * Registers a custom loss so that {@link #mapLossFunction} can resolve it by name.
     *
     * @param str           Keras loss name, as it appears in the model configuration
     * @param iLossFunction DL4J loss to return for that name
     */
    public static void registerCustomLoss(String str, ILossFunction iLossFunction) {
        customLoss.put(str, iLossFunction);
    }

    /** Removes every loss previously added with {@link #registerCustomLoss}. */
    public static void clearCustomLoss() {
        customLoss.clear();
    }

    /**
     * Maps a Keras loss-function name to the corresponding DL4J loss.
     *
     * <p>The name is lowercased and compared against the built-in Keras / tf.keras identifiers of
     * {@code kerasLayerConfiguration}. If none match, the custom-loss registry is consulted: first
     * with the lowercased name, then with the name exactly as supplied. The second lookup lets
     * losses registered under a mixed-case key still be found.
     *
     * @param str                     Keras loss name (case-insensitive for built-in losses)
     * @param kerasLayerConfiguration configuration providing the Keras loss identifiers
     * @return the matching DL4J loss function
     * @throws UnsupportedKerasConfigurationException if neither a built-in nor a registered
     *                                                custom loss matches
     * @throws NullPointerException                   if {@code str} is null
     */
    public static ILossFunction mapLossFunction(String str, KerasLayerConfiguration kerasLayerConfiguration)
            throws UnsupportedKerasConfigurationException {
        Objects.requireNonNull(str, "Keras loss name must not be null");
        // Locale.ROOT: with the default locale (e.g. Turkish), 'I' lowercases to a dotless 'ı',
        // which would make names such as "KLDivergence" miss the lowercase identifiers below.
        String lowerCase = str.toLowerCase(Locale.ROOT);

        LossFunctions.LossFunction builtIn = mapBuiltInLoss(lowerCase, kerasLayerConfiguration);
        if (builtIn != null) {
            return builtIn.getILossFunction();
        }

        ILossFunction custom = lookupCustomLoss(str, lowerCase);
        if (custom != null) {
            return custom;
        }
        throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + lowerCase);
    }

    /** Returns the built-in DL4J loss for a lowercased Keras name, or {@code null} if none matches. */
    private static LossFunctions.LossFunction mapBuiltInLoss(String name, KerasLayerConfiguration conf) {
        if (matchesAny(name, conf.getKERAS_LOSS_MEAN_SQUARED_ERROR(), conf.getKERAS_LOSS_MSE(),
                conf.getTF_KERAS_LOSS_MEAN_SQUARED_ERROR())) {
            return LossFunctions.LossFunction.SQUARED_LOSS;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_MEAN_ABSOLUTE_ERROR(), conf.getKERAS_LOSS_MAE(),
                conf.getTF_KERAS_LOSS_MEAN_ABSOLUTE_ERROR())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR(), conf.getKERAS_LOSS_MAPE(),
                conf.getTF_KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR(), conf.getKERAS_LOSS_MSLE(),
                conf.getTF_KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR())) {
            return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_SQUARED_HINGE(), conf.getTF_KERAS_LOSS_SQUARED_HINGE())) {
            return LossFunctions.LossFunction.SQUARED_HINGE;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_HINGE())) {
            return LossFunctions.LossFunction.HINGE;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY(),
                conf.getTF_KERAS_LOSS_SPARSE_CATEGORICAL_CROSS_ENTROPY())) {
            return LossFunctions.LossFunction.SPARSE_MCXENT;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_BINARY_CROSSENTROPY(), conf.getTF_KERAS_LOSS_BINARY_CROSSENTROPY())) {
            return LossFunctions.LossFunction.XENT;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {
            return LossFunctions.LossFunction.MCXENT;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_KULLBACK_LEIBLER_DIVERGENCE(), conf.getKERAS_LOSS_KLD(),
                conf.getTF_KERAS_LOSS_KLDIVERGENCE())) {
            return LossFunctions.LossFunction.KL_DIVERGENCE;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_POISSON())) {
            return LossFunctions.LossFunction.POISSON;
        }
        if (matchesAny(name, conf.getKERAS_LOSS_COSINE_PROXIMITY(), conf.getTF_KERAS_LOSS_COSINE_SIMILARITY())) {
            return LossFunctions.LossFunction.COSINE_PROXIMITY;
        }
        return null;
    }

    /** Looks up a registered custom loss by lowercased name first, then by the name as supplied. */
    private static ILossFunction lookupCustomLoss(String originalName, String lowerCaseName) {
        ILossFunction loss = customLoss.get(lowerCaseName);
        if (loss == null && !originalName.equals(lowerCaseName)) {
            loss = customLoss.get(originalName);
        }
        return loss;
    }

    /** True when {@code name} equals any candidate; null candidates never match. */
    private static boolean matchesAny(String name, String... candidates) {
        for (String candidate : candidates) {
            if (name.equals(candidate)) {
                return true;
            }
        }
        return false;
    }
}
