/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.core.evaluation;

import java.awt.Color;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.LengthUnit;
import org.deeplearning4j.ui.api.Style;
import org.deeplearning4j.ui.components.chart.ChartHistogram;
import org.deeplearning4j.ui.components.chart.ChartLine;
import org.deeplearning4j.ui.components.chart.style.StyleChart;
import org.deeplearning4j.ui.components.component.ComponentDiv;
import org.deeplearning4j.ui.components.component.style.StyleDiv;
import org.deeplearning4j.ui.components.table.ComponentTable;
import org.deeplearning4j.ui.components.table.style.StyleTable;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.standalone.StaticPageUtil;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.ReliabilityDiagram;
import org.nd4j.evaluation.curves.RocCurve;

public class EvaluationTools {
    private static final String ROC_TITLE = "ROC: TPR/Recall (y) vs. FPR (x)";
    private static final String PR_TITLE = "Precision (y) vs. Recall (x)";
    private static final String PR_THRESHOLD_TITLE = "Precision and Recall (y) vs. Classifier Threshold (x)";
    private static final double CHART_WIDTH_PX = 600.0;
    private static final double CHART_HEIGHT_PX = 400.0;
    private static final StyleChart CHART_STYLE = ((StyleChart.Builder)((StyleChart.Builder)((StyleChart.Builder)new StyleChart.Builder().width(600.0, LengthUnit.Px)).height(400.0, LengthUnit.Px)).margin(LengthUnit.Px, Integer.valueOf(60), Integer.valueOf(60), Integer.valueOf(75), Integer.valueOf(10))).strokeWidth(2.0).seriesColors(new Color[]{Color.BLUE, Color.LIGHT_GRAY}).build();
    private static final StyleChart CHART_STYLE_PRECISION_RECALL = ((StyleChart.Builder)((StyleChart.Builder)((StyleChart.Builder)new StyleChart.Builder().width(600.0, LengthUnit.Px)).height(400.0, LengthUnit.Px)).margin(LengthUnit.Px, Integer.valueOf(60), Integer.valueOf(60), Integer.valueOf(40), Integer.valueOf(10))).strokeWidth(2.0).seriesColors(new Color[]{Color.BLUE, Color.GREEN}).build();
    private static final StyleTable TABLE_STYLE = ((StyleTable.Builder)((StyleTable.Builder)new StyleTable.Builder().backgroundColor(Color.WHITE).headerColor(Color.LIGHT_GRAY).borderWidth(1).columnWidths(LengthUnit.Percent, new double[]{50.0, 50.0}).width(400.0, LengthUnit.Px)).height(200.0, LengthUnit.Px)).build();
    private static final StyleDiv OUTER_DIV_STYLE = ((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(1200.0, LengthUnit.Px)).height(400.0, LengthUnit.Px)).build();
    private static final StyleDiv OUTER_DIV_STYLE_WIDTH_ONLY = ((StyleDiv.Builder)new StyleDiv.Builder().width(1200.0, LengthUnit.Px)).build();
    private static final StyleDiv INNER_DIV_STYLE = ((StyleDiv.Builder)new StyleDiv.Builder().width(600.0, LengthUnit.Px)).floatValue(StyleDiv.FloatValue.left).build();
    private static final StyleDiv PAD_DIV_STYLE = ((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(600.0, LengthUnit.Px)).height(100.0, LengthUnit.Px)).floatValue(StyleDiv.FloatValue.left).build();
    private static final ComponentDiv PAD_DIV = new ComponentDiv((Style)PAD_DIV_STYLE, new Component[0]);
    private static final StyleText HEADER_TEXT_STYLE = new StyleText.Builder().color(Color.BLACK).fontSize(16.0).underline(true).build();
    private static final StyleDiv HEADER_DIV_STYLE = ((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(1050.0, LengthUnit.Px)).height(30.0, LengthUnit.Px)).backgroundColor(Color.LIGHT_GRAY)).margin(LengthUnit.Px, Integer.valueOf(5), Integer.valueOf(5), Integer.valueOf(200), Integer.valueOf(10))).floatValue(StyleDiv.FloatValue.left).build();
    private static final StyleDiv HEADER_DIV_STYLE_1400 = ((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(1250.0, LengthUnit.Px)).height(30.0, LengthUnit.Px)).backgroundColor(Color.LIGHT_GRAY)).margin(LengthUnit.Px, Integer.valueOf(5), Integer.valueOf(5), Integer.valueOf(200), Integer.valueOf(10))).floatValue(StyleDiv.FloatValue.left).build();
    private static final StyleDiv HEADER_DIV_PAD_STYLE = ((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(1200.0, LengthUnit.Px)).height(150.0, LengthUnit.Px)).backgroundColor(Color.WHITE)).build();
    private static final StyleDiv HEADER_DIV_TEXT_PAD_STYLE = ((StyleDiv.Builder)((StyleDiv.Builder)((StyleDiv.Builder)new StyleDiv.Builder().width(120.0, LengthUnit.Px)).height(30.0, LengthUnit.Px)).backgroundColor(Color.LIGHT_GRAY)).floatValue(StyleDiv.FloatValue.left).build();
    private static final ComponentTable INFO_TABLE = new ComponentTable.Builder(new StyleTable.Builder().backgroundColor(Color.WHITE).borderWidth(0).build()).content((String[][])new String[][]{{"Precision", "(true positives) / (true positives + false positives)"}, {"True Positive Rate (Recall)", "(true positives) / (data positives)"}, {"False Positive Rate", "(false positives) / (data negatives)"}}).build();

    private EvaluationTools() {
    }

    public static void exportRocChartsToHtmlFile(ROC roc, File file) throws IOException {
        String rocAsHtml = EvaluationTools.rocChartToHtml(roc);
        FileUtils.writeStringToFile((File)file, (String)rocAsHtml);
    }

    public static void exportRocChartsToHtmlFile(ROCMultiClass roc, File file) throws Exception {
        String rocAsHtml = EvaluationTools.rocChartToHtml(roc);
        FileUtils.writeStringToFile((File)file, (String)rocAsHtml, (Charset)StandardCharsets.UTF_8);
    }

    public static String rocChartToHtml(ROC roc) {
        RocCurve rocCurve = roc.getRocCurve();
        Component c = EvaluationTools.getRocFromPoints(ROC_TITLE, rocCurve, roc.getCountActualPositive(), roc.getCountActualNegative(), roc.calculateAUC(), roc.calculateAUCPR());
        Component c2 = EvaluationTools.getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());
        return StaticPageUtil.renderHTML((Component[])new Component[]{c, c2});
    }

    public static String rocChartToHtml(ROCMultiClass rocMultiClass) {
        return EvaluationTools.rocChartToHtml(rocMultiClass, null);
    }

    public static String rocChartToHtml(ROCMultiClass rocMultiClass, List<String> classNames) {
        int n = rocMultiClass.getNumClasses();
        ArrayList<Object> components = new ArrayList<Object>(n);
        for (int i = 0; i < n; ++i) {
            RocCurve roc = rocMultiClass.getRocCurve(i);
            String headerText = "Class " + i;
            if (classNames != null && classNames.size() > i) {
                headerText = headerText + " (" + classNames.get(i) + ")";
            }
            headerText = headerText + " vs. All";
            ComponentDiv headerDivPad = new ComponentDiv((Style)HEADER_DIV_PAD_STYLE, new Component[0]);
            components.add(headerDivPad);
            ComponentDiv headerDivLeft = new ComponentDiv((Style)HEADER_DIV_TEXT_PAD_STYLE, new Component[0]);
            ComponentDiv headerDiv = new ComponentDiv((Style)HEADER_DIV_STYLE, new Component[]{new ComponentText(headerText, HEADER_TEXT_STYLE)});
            Component c = EvaluationTools.getRocFromPoints(ROC_TITLE, roc, rocMultiClass.getCountActualPositive(i), rocMultiClass.getCountActualNegative(i), rocMultiClass.calculateAUC(i), rocMultiClass.calculateAUCPR(i));
            Component c2 = EvaluationTools.getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i));
            components.add(headerDivLeft);
            components.add(headerDiv);
            components.add(c);
            components.add(c2);
        }
        return StaticPageUtil.renderHTML(components);
    }

    public static void exportevaluationCalibrationToHtmlFile(EvaluationCalibration ec, File file) throws IOException {
        String asHtml = EvaluationTools.evaluationCalibrationToHtml(ec);
        FileUtils.writeStringToFile((File)file, (String)asHtml);
    }

    public static String evaluationCalibrationToHtml(EvaluationCalibration ec) {
        ArrayList<ComponentDiv> components = new ArrayList<ComponentDiv>();
        int nClasses = ec.numClasses();
        ComponentDiv headerDiv = new ComponentDiv((Style)HEADER_DIV_STYLE_1400, new Component[]{new ComponentText("Labels and Network Prediction Class Distributions (X: Class Index. Y: Count)", HEADER_TEXT_STYLE)});
        components.add(headerDiv);
        int[] labelCounts = ec.getLabelCountsEachClass();
        int[] predictedCounts = ec.getPredictionCountsEachClass();
        ChartHistogram.Builder chbLabels = new ChartHistogram.Builder("Label Class Distribution", CHART_STYLE);
        ChartHistogram.Builder chbPredictions = new ChartHistogram.Builder("Predicted Class Distribution", CHART_STYLE);
        for (int i = 0; i < nClasses; ++i) {
            double lower = (double)i - 0.5;
            double upper = (double)i + 0.5;
            chbLabels.addBin(lower, upper, (double)labelCounts[i]);
            chbPredictions.addBin(lower, upper, (double)predictedCounts[i]);
        }
        ChartHistogram chL = chbLabels.build();
        ChartHistogram chP = chbPredictions.build();
        components.add(new ComponentDiv((Style)OUTER_DIV_STYLE_WIDTH_ONLY, new Component[]{chL, chP}));
        headerDiv = new ComponentDiv((Style)HEADER_DIV_STYLE_1400, new Component[]{new ComponentText("Reliability Diagrams (X: Mean Predicted Value. Y: Fraction Positives)", HEADER_TEXT_STYLE)});
        components.add(headerDiv);
        ArrayList<Object> sectionDiv = new ArrayList<Object>();
        double[] zeroOne = new double[]{0.0, 1.0};
        for (int i = 0; i < nClasses; ++i) {
            ReliabilityDiagram rd = ec.getReliabilityDiagram(i);
            double[] x = rd.getMeanPredictedValueX();
            double[] y = rd.getFractionPositivesY();
            String title = rd.getTitle();
            ChartLine cl = new ChartLine.Builder(title, CHART_STYLE).addSeries("Classifier", x, y).addSeries("Ideal Classifier", zeroOne, zeroOne).build();
            sectionDiv.add(cl);
        }
        components.add(new ComponentDiv((Style)OUTER_DIV_STYLE_WIDTH_ONLY, sectionDiv));
        headerDiv = new ComponentDiv((Style)HEADER_DIV_STYLE_1400, new Component[]{new ComponentText("Network Predictions - Residual Plots - |Label(i) - P(class(i))|", HEADER_TEXT_STYLE)});
        components.add(headerDiv);
        sectionDiv = new ArrayList();
        Histogram resPlotAll = ec.getResidualPlotAllClasses();
        sectionDiv.add(EvaluationTools.getHistogram(resPlotAll));
        for (int i = 0; i < nClasses; ++i) {
            Histogram resPlotCurrent = ec.getResidualPlot(i);
            sectionDiv.add(EvaluationTools.getHistogram(resPlotCurrent));
        }
        components.add(new ComponentDiv((Style)OUTER_DIV_STYLE_WIDTH_ONLY, sectionDiv));
        headerDiv = new ComponentDiv((Style)HEADER_DIV_STYLE_1400, new Component[]{new ComponentText("Network Prediction Probabilities (X: P(class). Y: Count)", HEADER_TEXT_STYLE)});
        components.add(headerDiv);
        sectionDiv = new ArrayList();
        Histogram allProbs = ec.getProbabilityHistogramAllClasses();
        sectionDiv.add(EvaluationTools.getHistogram(allProbs));
        for (int i = 0; i < nClasses; ++i) {
            Histogram classProbs = ec.getProbabilityHistogram(i);
            sectionDiv.add(EvaluationTools.getHistogram(classProbs));
        }
        components.add(new ComponentDiv((Style)OUTER_DIV_STYLE_WIDTH_ONLY, sectionDiv));
        return StaticPageUtil.renderHTML(components);
    }

    private static Component getRocFromPoints(String title, RocCurve roc, long positiveCount, long negativeCount, double auc, double aucpr) {
        double[] zeroOne = new double[]{0.0, 1.0};
        ChartLine chartLine = ((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)new ChartLine.Builder(title, CHART_STYLE).setXMin(Double.valueOf(0.0))).setXMax(Double.valueOf(1.0))).setYMin(Double.valueOf(0.0))).setYMax(Double.valueOf(1.0))).addSeries("ROC", roc.getX(), roc.getY()).addSeries("", zeroOne, zeroOne).build();
        ComponentTable ct = new ComponentTable.Builder(TABLE_STYLE).header(new String[]{"Field", "Value"}).content((String[][])new String[][]{{"AUROC: Area under ROC:", String.format("%.5f", auc)}, {"AUPRC: Area under P/R:", String.format("%.5f", aucpr)}, {"Total Data Positive Count", String.valueOf(positiveCount)}, {"Total Data Negative Count", String.valueOf(negativeCount)}}).build();
        ComponentDiv divLeft = new ComponentDiv((Style)INNER_DIV_STYLE, new Component[]{PAD_DIV, ct, PAD_DIV, INFO_TABLE});
        ComponentDiv divRight = new ComponentDiv((Style)INNER_DIV_STYLE, new Component[]{chartLine});
        return new ComponentDiv((Style)OUTER_DIV_STYLE, new Component[]{divLeft, divRight});
    }

    private static Component getPRCharts(String precisionRecallTitle, String prThresholdTitle, PrecisionRecallCurve prCurve) {
        ComponentDiv divLeft = new ComponentDiv((Style)INNER_DIV_STYLE, new Component[]{EvaluationTools.getPrecisionRecallCurve(precisionRecallTitle, prCurve)});
        ComponentDiv divRight = new ComponentDiv((Style)INNER_DIV_STYLE, new Component[]{EvaluationTools.getPrecisionRecallVsThreshold(prThresholdTitle, prCurve)});
        return new ComponentDiv((Style)OUTER_DIV_STYLE, new Component[]{divLeft, divRight});
    }

    private static Component getPrecisionRecallCurve(String title, PrecisionRecallCurve prCurve) {
        double[] recallX = prCurve.getRecall();
        double[] precisionY = prCurve.getPrecision();
        return ((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)new ChartLine.Builder(title, CHART_STYLE).setXMin(Double.valueOf(0.0))).setXMax(Double.valueOf(1.0))).setYMin(Double.valueOf(0.0))).setYMax(Double.valueOf(1.0))).addSeries("P vs R", recallX, precisionY).build();
    }

    private static Component getPrecisionRecallVsThreshold(String title, PrecisionRecallCurve prCurve) {
        double[] recallY = prCurve.getRecall();
        double[] precisionY = prCurve.getPrecision();
        double[] thresholdX = prCurve.getThreshold();
        return ((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)((ChartLine.Builder)new ChartLine.Builder(title, CHART_STYLE_PRECISION_RECALL).setXMin(Double.valueOf(0.0))).setXMax(Double.valueOf(1.0))).setYMin(Double.valueOf(0.0))).setYMax(Double.valueOf(1.0))).addSeries("Precision", thresholdX, precisionY).addSeries("Recall", thresholdX, recallY).showLegend(true)).build();
    }

    private static Component getHistogram(Histogram histogram) {
        ChartHistogram.Builder chb = new ChartHistogram.Builder(histogram.getTitle(), CHART_STYLE);
        double[] lower = histogram.getBinLowerBounds();
        double[] upper = histogram.getBinUpperBounds();
        int[] counts = histogram.getBinCounts();
        for (int i = 0; i < counts.length; ++i) {
            chb.addBin(lower[i], upper[i], (double)counts[i]);
        }
        return chb.build();
    }
}

