/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.util.Index;
import smile.validation.Bag;
import smile.validation.ClassificationMetrics;
import smile.validation.ClassificationValidations;
import smile.validation.metric.ConfusionMatrix;

/**
 * The result of validating a single classification model against a test set.
 *
 * <p>Note that the array-valued components are stored as given (no defensive
 * copies), and the record's generated {@code equals}/{@code hashCode} compare
 * them by reference.
 *
 * @param <M> the type of the validated model.
 * @param model the trained model.
 * @param truth the ground-truth labels of the test samples.
 * @param prediction the labels predicted by {@code model} for the test samples.
 * @param posteriori the posterior probabilities written by the model during
 *                   prediction, or {@code null} when the result was built
 *                   without them (hard classifiers).
 * @param confusion the confusion matrix computed from {@code truth} and
 *                  {@code prediction}.
 * @param metrics the validation metrics, including fit and score times.
 */
public record ClassificationValidation<M>(M model, int[] truth, int[] prediction, double[][] posteriori, ConfusionMatrix confusion, ClassificationMetrics metrics) implements Serializable
{
    private static final long serialVersionUID = 3L;

    /**
     * Creates a validation result without posterior probabilities.
     * The confusion matrix and the metrics are derived from the labels.
     *
     * @param model the trained model.
     * @param fitTime the time spent fitting the model, in milliseconds.
     * @param scoreTime the time spent scoring the test set, in milliseconds.
     * @param truth the ground-truth labels.
     * @param prediction the predicted labels.
     */
    public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction) {
        this(model, truth, prediction, null, ConfusionMatrix.of(truth, prediction), ClassificationMetrics.of(fitTime, scoreTime, truth, prediction));
    }

    /**
     * Creates a validation result that also carries posterior probabilities.
     * The confusion matrix is derived from the labels; the metrics are derived
     * from the labels and the posteriors.
     *
     * @param model the trained model.
     * @param fitTime the time spent fitting the model, in milliseconds.
     * @param scoreTime the time spent scoring the test set, in milliseconds.
     * @param truth the ground-truth labels.
     * @param prediction the predicted labels.
     * @param posteriori the posterior probabilities of the test samples.
     */
    public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction, double[][] posteriori) {
        this(model, truth, prediction, posteriori, ConfusionMatrix.of(truth, prediction), ClassificationMetrics.of(fitTime, scoreTime, truth, prediction, posteriori));
    }

    /**
     * Returns the string form of the validation metrics.
     *
     * @return {@code metrics.toString()}.
     */
    @Override
    public String toString() {
        return metrics.toString();
    }

    /**
     * Returns the wall-clock time elapsed since {@code startNanos}.
     *
     * @param startNanos an earlier reading of {@link System#nanoTime()}.
     * @return the elapsed time in milliseconds.
     */
    private static double millisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1000000.0;
    }

    /**
     * Trains a model on {@code (x, y)} and validates it on {@code (testx, testy)}.
     * If the model reports {@code soft()}, posterior probabilities are collected
     * alongside the predictions.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param testx the test samples.
     * @param testy the test labels.
     * @param trainer the function that fits a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation result.
     */
    public static <T, M extends Classifier<T>> ClassificationValidation<M> of(T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[], int[], M> trainer) {
        long fitStart = System.nanoTime();
        M model = trainer.apply(x, y);
        double fitTime = millisSince(fitStart);

        // The score timer starts before the soft/hard branch, so for soft
        // classifiers it also covers allocating the posterior buffer.
        long scoreStart = System.nanoTime();
        if (model.soft()) {
            double[][] posteriori = new double[testx.length][model.numClasses()];
            int[] prediction = model.predict(testx, posteriori);
            double scoreTime = millisSince(scoreStart);
            return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction, posteriori);
        }

        int[] prediction = model.predict(testx);
        double scoreTime = millisSince(scoreStart);
        return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction);
    }

    /**
     * Runs one round of validation per bag: each round trains on the bag's
     * {@code samples()} indices and tests on its {@code oob()} indices.
     *
     * @param bags the train/test index splits.
     * @param x the samples.
     * @param y the labels.
     * @param trainer the function that fits a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the results of all rounds.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> of(Bag[] bags, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        ArrayList<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            T[] trainx = MathEx.slice(x, bag.samples());
            int[] trainy = MathEx.slice(y, bag.samples());
            T[] testx = MathEx.slice(x, bag.oob());
            int[] testy = MathEx.slice(y, bag.oob());
            rounds.add(ClassificationValidation.of(trainx, trainy, testx, testy, trainer));
        }
        return ClassificationValidations.of(rounds);
    }

    /**
     * Trains a data frame model on {@code train} and validates it on
     * {@code test}, predicting the test rows one at a time. If the model
     * reports {@code soft()}, posterior probabilities are collected alongside
     * the predictions.
     *
     * @param formula the model formula; its response provides the labels.
     * @param train the training data.
     * @param test the test data.
     * @param trainer the function that fits a model from a formula and data.
     * @param <M> the model type.
     * @return the validation result.
     */
    public static <M extends DataFrameClassifier> ClassificationValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula, DataFrame, M> trainer) {
        // The training response values are not used below. The call is kept
        // because it runs before training and may bind or validate the formula
        // against the training data. NOTE(review): confirm and drop if redundant.
        formula.y(train).toIntArray();
        int[] testy = formula.y(test).toIntArray();

        long fitStart = System.nanoTime();
        M model = trainer.apply(formula, train);
        double fitTime = millisSince(fitStart);

        int n = test.size();
        int[] prediction = new int[n];
        if (model.soft()) {
            double[][] posteriori = new double[n][model.numClasses()];
            long scoreStart = System.nanoTime();
            for (int i = 0; i < n; ++i) {
                prediction[i] = model.predict(test.get(i), posteriori[i]);
            }
            double scoreTime = millisSince(scoreStart);
            return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction, posteriori);
        }

        long scoreStart = System.nanoTime();
        for (int i = 0; i < n; ++i) {
            prediction[i] = model.predict(test.get(i));
        }
        double scoreTime = millisSince(scoreStart);
        return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction);
    }

    /**
     * Runs one round of validation per bag on a data frame: each round trains
     * on the rows selected by the bag's {@code samples()} indices and tests on
     * the rows selected by its {@code oob()} indices.
     *
     * @param bags the train/test index splits.
     * @param formula the model formula.
     * @param data the full data set.
     * @param trainer the function that fits a model from a formula and data.
     * @param <M> the model type.
     * @return the results of all rounds.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        ArrayList<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            DataFrame train = data.get(Index.of(bag.samples()));
            DataFrame test = data.get(Index.of(bag.oob()));
            rounds.add(ClassificationValidation.of(formula, train, test, trainer));
        }
        return ClassificationValidations.of(rounds);
    }
}

