/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;

/**
 * Leave-one-out cross validation (LOOCV).
 *
 * <p>For a data set of {@code n} samples, LOOCV performs {@code n} rounds. In round {@code i},
 * sample {@code i} is held out as the single test sample and the remaining {@code n - 1}
 * samples form the training set.
 *
 * <p>The index arrays are materialized eagerly, so memory use grows as {@code n * (n - 1)}
 * integers.
 */
public class LOOCV {
    /**
     * Training sample indices of each round: {@code train[i]} holds every index in
     * {@code [0, n)} except {@code i}, in increasing order. The arrays are exposed directly,
     * so callers must not modify them.
     */
    public final int[][] train;
    /** Test sample index of each round; {@code test[i] == i}. */
    public final int[] test;

    /**
     * Builds the leave-one-out splits for a data set of the given size.
     *
     * @param n the number of samples.
     * @throws IllegalArgumentException if {@code n <= 0}.
     */
    public LOOCV(int n) {
        // n == 0 must be rejected here: `new int[0][-1]` would otherwise throw
        // NegativeArraySizeException, because every array dimension is checked.
        if (n <= 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }

        this.train = new int[n][n - 1];
        this.test = new int[n];
        for (int i = 0; i < n; ++i) {
            this.test[i] = i;
            int[] fold = this.train[i];
            // Indices before the held-out sample keep their position...
            for (int j = 0; j < i; ++j) {
                fold[j] = j;
            }
            // ...and indices after it shift left by one slot.
            for (int j = i + 1; j < n; ++j) {
                fold[j - 1] = j;
            }
        }
    }

    /**
     * Runs leave-one-out cross validation of a classification model.
     *
     * @param x the samples.
     * @param y the sample labels.
     * @param trainer the function that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the prediction for each sample, made by the model trained without that sample.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length or are empty.
     */
    public static <T> int[] classification(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        checkSize(x.length, y.length);
        int n = x.length;
        LOOCV cv = new LOOCV(n);
        int[] prediction = new int[n];
        for (int i = 0; i < n; ++i) {
            T[] trainx = MathEx.slice(x, cv.train[i]);
            int[] trainy = MathEx.slice(y, cv.train[i]);
            Classifier<T> model = trainer.apply(trainx, trainy);
            prediction[cv.test[i]] = model.predict(x[cv.test[i]]);
        }
        return prediction;
    }

    /**
     * Runs leave-one-out cross validation of a data frame classification model.
     *
     * @param formula the model specification.
     * @param data the training data.
     * @param trainer the function that trains a model on a training subset.
     * @return the prediction for each row, made by the model trained without that row.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public static int[] classification(Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        int n = data.size();
        LOOCV cv = new LOOCV(n);
        int[] prediction = new int[n];
        for (int i = 0; i < n; ++i) {
            DataFrameClassifier model = trainer.apply(formula, data.of(cv.train[i]));
            prediction[cv.test[i]] = model.predict((Tuple) data.get(cv.test[i]));
        }
        return prediction;
    }

    /**
     * Runs leave-one-out cross validation of a regression model.
     *
     * @param x the samples.
     * @param y the response variable.
     * @param trainer the function that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the prediction for each sample, made by the model trained without that sample.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length or are empty.
     */
    public static <T> double[] regression(T[] x, double[] y, BiFunction<T[], double[], Regression<T>> trainer) {
        checkSize(x.length, y.length);
        int n = x.length;
        LOOCV cv = new LOOCV(n);
        double[] prediction = new double[n];
        for (int i = 0; i < n; ++i) {
            T[] trainx = MathEx.slice(x, cv.train[i]);
            double[] trainy = MathEx.slice(y, cv.train[i]);
            Regression<T> model = trainer.apply(trainx, trainy);
            prediction[cv.test[i]] = model.predict(x[cv.test[i]]);
        }
        return prediction;
    }

    /**
     * Runs leave-one-out cross validation of a data frame regression model.
     *
     * @param formula the model specification.
     * @param data the training data.
     * @param trainer the function that trains a model on a training subset.
     * @return the prediction for each row, made by the model trained without that row.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public static double[] regression(Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        int n = data.size();
        LOOCV cv = new LOOCV(n);
        double[] prediction = new double[n];
        for (int i = 0; i < n; ++i) {
            DataFrameRegression model = trainer.apply(formula, data.of(cv.train[i]));
            prediction[cv.test[i]] = model.predict((Tuple) data.get(cv.test[i]));
        }
        return prediction;
    }

    /** Fails fast with a clear message when the sample and response arrays disagree in length. */
    private static void checkSize(int xLength, int yLength) {
        if (xLength != yLength) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", xLength, yLength));
        }
    }
}

