/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.List;
import smile.math.MathEx;
import smile.validation.ClassificationMetrics;
import smile.validation.ClassificationValidation;

/**
 * Classification metrics summarized over several validation rounds.
 *
 * @param rounds the per-round validation results that were summarized.
 * @param avg    the mean of each metric across {@code rounds}; the integer
 *               fields (validation size and error count) are rounded to the
 *               nearest int.
 * @param std    the standard deviation of each metric across {@code rounds},
 *               as computed by {@code MathEx.stdev}; the integer fields are
 *               rounded to the nearest int.
 * @param <M> the type parameter shared with each {@link ClassificationValidation} round.
 */
public record ClassificationValidations<M>(List<ClassificationValidation<M>> rounds, ClassificationMetrics avg, ClassificationMetrics std) implements Serializable {
    private static final long serialVersionUID = 3L;

    /**
     * Builds the summary of a list of validation rounds: each metric of
     * every round is collected, then reduced to its mean and its standard
     * deviation.
     *
     * @param rounds the validation results, one per round. The list itself is
     *               stored in the returned record (not copied).
     * @param <M> the type parameter shared with each round.
     * @return the rounds together with their mean and standard-deviation metrics.
     */
    public static <M> ClassificationValidations<M> of(List<ClassificationValidation<M>> rounds) {
        final int n = rounds.size();

        // One array per metric; slot r holds the value reported by round r.
        double[] fitTimes = new double[n];
        double[] scoreTimes = new double[n];
        int[] sizes = new int[n];
        int[] errors = new int[n];
        double[] accuracies = new double[n];
        double[] sensitivities = new double[n];
        double[] specificities = new double[n];
        double[] precisions = new double[n];
        double[] f1Scores = new double[n];
        double[] mccs = new double[n];
        double[] aucs = new double[n];
        double[] logLosses = new double[n];
        double[] crossEntropies = new double[n];

        int r = 0;
        for (ClassificationValidation<M> round : rounds) {
            ClassificationMetrics cm = round.metrics();
            fitTimes[r] = cm.fitTime();
            scoreTimes[r] = cm.scoreTime();
            sizes[r] = cm.size();
            errors[r] = cm.error();
            accuracies[r] = cm.accuracy();
            sensitivities[r] = cm.sensitivity();
            specificities[r] = cm.specificity();
            precisions[r] = cm.precision();
            f1Scores[r] = cm.f1();
            mccs[r] = cm.mcc();
            aucs[r] = cm.auc();
            logLosses[r] = cm.logloss();
            crossEntropies[r] = cm.crossEntropy();
            r++;
        }

        // Integer-valued metrics keep their int type in the summary, hence the rounding.
        ClassificationMetrics mean = new ClassificationMetrics(
                MathEx.mean(fitTimes),
                MathEx.mean(scoreTimes),
                (int) Math.round(MathEx.mean(sizes)),
                (int) Math.round(MathEx.mean(errors)),
                MathEx.mean(accuracies),
                MathEx.mean(sensitivities),
                MathEx.mean(specificities),
                MathEx.mean(precisions),
                MathEx.mean(f1Scores),
                MathEx.mean(mccs),
                MathEx.mean(aucs),
                MathEx.mean(logLosses),
                MathEx.mean(crossEntropies));

        ClassificationMetrics spread = new ClassificationMetrics(
                MathEx.stdev(fitTimes),
                MathEx.stdev(scoreTimes),
                (int) Math.round(MathEx.stdev(sizes)),
                (int) Math.round(MathEx.stdev(errors)),
                MathEx.stdev(accuracies),
                MathEx.stdev(sensitivities),
                MathEx.stdev(specificities),
                MathEx.stdev(precisions),
                MathEx.stdev(f1Scores),
                MathEx.stdev(mccs),
                MathEx.stdev(aucs),
                MathEx.stdev(logLosses),
                MathEx.stdev(crossEntropies));

        return new ClassificationValidations<>(rounds, mean, spread);
    }

    /**
     * Renders the summary as a brace-delimited, multi-line "mean ± std" listing.
     * Ratio metrics are scaled by 100 and printed as percentages. A ratio metric
     * whose mean is NaN is omitted. Log loss is printed when its mean is
     * defined; otherwise cross entropy is printed if its mean is defined.
     */
    @Override
    public String toString() {
        ClassificationMetrics mu = this.avg;
        ClassificationMetrics sd = this.std;

        StringBuilder out = new StringBuilder("{\n");
        out.append(String.format("  fit time: %.3f ms \u00b1 %.3f,\n", mu.fitTime(), sd.fitTime()))
           .append(String.format("  score time: %.3f ms \u00b1 %.3f,\n", mu.scoreTime(), sd.scoreTime()))
           .append(String.format("  validation data size: %d \u00b1 %d,\n", mu.size(), sd.size()))
           .append(String.format("  error: %d \u00b1 %d,\n", mu.error(), sd.error()))
           .append(String.format("  accuracy: %.2f%% \u00b1 %.2f", 100.0 * mu.accuracy(), 100.0 * sd.accuracy()));

        appendPercentIfDefined(out, ",\n  sensitivity: %.2f%% \u00b1 %.2f", mu.sensitivity(), sd.sensitivity());
        appendPercentIfDefined(out, ",\n  specificity: %.2f%% \u00b1 %.2f", mu.specificity(), sd.specificity());
        appendPercentIfDefined(out, ",\n  precision: %.2f%% \u00b1 %.2f", mu.precision(), sd.precision());
        appendPercentIfDefined(out, ",\n  F1 score: %.2f%% \u00b1 %.2f", mu.f1(), sd.f1());
        appendPercentIfDefined(out, ",\n  MCC: %.2f%% \u00b1 %.2f", mu.mcc(), sd.mcc());
        appendPercentIfDefined(out, ",\n  AUC: %.2f%% \u00b1 %.2f", mu.auc(), sd.auc());

        // Only one probabilistic loss is shown; log loss takes precedence.
        double meanLogLoss = mu.logloss();
        if (Double.isNaN(meanLogLoss)) {
            double meanCrossEntropy = mu.crossEntropy();
            if (!Double.isNaN(meanCrossEntropy)) {
                out.append(String.format(",\n  cross entropy: %.4f \u00b1 %.4f", meanCrossEntropy, sd.crossEntropy()));
            }
        } else {
            out.append(String.format(",\n  log loss: %.4f \u00b1 %.4f", meanLogLoss, sd.logloss()));
        }

        return out.append("\n}").toString();
    }

    /**
     * Appends one percentage line (both values scaled by 100) using
     * {@code format}, unless {@code mean} is NaN.
     */
    private static void appendPercentIfDefined(StringBuilder sb, String format, double mean, double stdev) {
        if (Double.isNaN(mean)) {
            return;
        }
        sb.append(String.format(format, 100.0 * mean, 100.0 * stdev));
    }
}

