/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.Accuracy;
import smile.validation.RMSE;

public class Bootstrap {
    public final int k;
    public final int[][] train;
    public final int[][] test;

    public Bootstrap(int n, int k) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        if (k < 0) {
            throw new IllegalArgumentException("Invalid number of bootstrap: " + k);
        }
        this.k = k;
        this.train = new int[k][n];
        this.test = new int[k][];
        for (int j = 0; j < k; ++j) {
            int i;
            boolean[] hit = new boolean[n];
            int hits = 0;
            for (i = 0; i < n; ++i) {
                int r;
                this.train[j][i] = r = MathEx.randomInt((int)n);
                if (hit[r]) continue;
                ++hits;
                hit[r] = true;
            }
            this.test[j] = new int[n - hits];
            int p = 0;
            for (i = 0; i < n; ++i) {
                if (hit[i]) continue;
                this.test[j][p++] = i;
            }
        }
    }

    public <T> double[] classification(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        double[] error = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            Object[] trainx = MathEx.slice((Object[])x, (int[])this.train[i]);
            int[] trainy = MathEx.slice((int[])y, (int[])this.train[i]);
            Object[] testx = MathEx.slice((Object[])x, (int[])this.test[i]);
            int[] testy = MathEx.slice((int[])y, (int[])this.test[i]);
            Classifier<Object> model = trainer.apply((Object[][])trainx, trainy);
            int[] prediction = model.predict(testx);
            error[i] = 1.0 - Accuracy.of(testy, prediction);
        }
        return error;
    }

    public double[] classification(Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        double[] error = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            DataFrameClassifier model = trainer.apply(formula, data.of(this.train[i]));
            DataFrame oob = data.of(this.test[i]);
            int[] prediction = model.predict(oob);
            int[] testy = model.formula().y(oob).toIntArray();
            error[i] = 1.0 - Accuracy.of(testy, prediction);
        }
        return error;
    }

    public <T> double[] regression(T[] x, double[] y, BiFunction<T[], double[], Regression<T>> trainer) {
        double[] rmse = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            Object[] trainx = MathEx.slice((Object[])x, (int[])this.train[i]);
            double[] trainy = MathEx.slice((double[])y, (int[])this.train[i]);
            Object[] testx = MathEx.slice((Object[])x, (int[])this.test[i]);
            double[] testy = MathEx.slice((double[])y, (int[])this.test[i]);
            Regression<Object> model = trainer.apply((Object[][])trainx, trainy);
            double[] prediction = model.predict(testx);
            rmse[i] = RMSE.of(testy, prediction);
        }
        return rmse;
    }

    public double[] regression(Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        double[] rmse = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            DataFrameRegression model = trainer.apply(formula, data.of(this.train[i]));
            DataFrame oob = data.of(this.test[i]);
            double[] prediction = model.predict(oob);
            double[] testy = model.formula().y(oob).toDoubleArray();
            rmse[i] = RMSE.of(testy, prediction);
        }
        return rmse;
    }

    public static <T> double[] classification(int k, T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        Bootstrap cv = new Bootstrap(x.length, k);
        return cv.classification(x, y, trainer);
    }

    public static double[] classification(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        Bootstrap cv = new Bootstrap(data.size(), k);
        return cv.classification(formula, data, trainer);
    }

    public static <T> double[] regression(int k, T[] x, double[] y, BiFunction<T[], double[], Regression<T>> trainer) {
        Bootstrap cv = new Bootstrap(x.length, k);
        return cv.regression(x, y, trainer);
    }

    public static double[] regression(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        Bootstrap cv = new Bootstrap(data.size(), k);
        return cv.regression(formula, data, trainer);
    }
}

