/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.Bag;
import smile.validation.RegressionMetrics;
import smile.validation.RegressionValidations;

/**
 * The outcome of validating a regression model on a held-out test set:
 * the trained model, the test responses, the model's predictions on the
 * test inputs, and the metrics computed from them.
 *
 * @param <M> the type of the trained model.
 */
public class RegressionValidation<M> implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The model returned by the trainer. */
    public final M model;
    /** The response values of the test data. */
    public final double[] truth;
    /** The model's predictions on the test data. */
    public final double[] prediction;
    /** The metrics computed from {@link #truth} and {@link #prediction}. */
    public final RegressionMetrics metrics;

    /**
     * Constructor.
     *
     * @param model the trained model.
     * @param truth the response values of the test data.
     * @param prediction the model's predictions on the test data.
     * @param metrics the validation metrics.
     */
    public RegressionValidation(M model, double[] truth, double[] prediction, RegressionMetrics metrics) {
        this.model = model;
        this.truth = truth;
        this.prediction = prediction;
        this.metrics = metrics;
    }

    /** Returns the string representation of {@link #metrics}. */
    @Override
    public String toString() {
        return metrics.toString();
    }

    /**
     * Trains a model on {@code (x, y)} and validates it on {@code (testx, testy)}.
     * Training and scoring times are measured in milliseconds and passed to
     * {@link RegressionMetrics#of}.
     *
     * @param x the training inputs.
     * @param y the training responses.
     * @param testx the test inputs.
     * @param testy the test responses.
     * @param trainer the function that builds a model from training inputs and responses.
     * @param <T> the type of an input sample.
     * @param <M> the type of the model.
     * @return the validation result.
     */
    public static <T, M extends Regression<T>> RegressionValidation<M> of(T[] x, double[] y, T[] testx, double[] testy, BiFunction<T[], double[], M> trainer) {
        long start = System.nanoTime();
        M model = trainer.apply(x, y);
        double fitTime = elapsedMillis(start);

        start = System.nanoTime();
        double[] prediction = model.predict(testx);
        double scoreTime = elapsedMillis(start);

        RegressionMetrics metrics = RegressionMetrics.of(fitTime, scoreTime, testy, prediction);
        return new RegressionValidation<>(model, testy, prediction, metrics);
    }

    /**
     * Runs one train/validate round per bag. Each round trains on the rows
     * indexed by {@code bag.samples} and validates on the rows indexed by
     * {@code bag.oob}.
     *
     * @param bags the train/test index splits.
     * @param x the inputs.
     * @param y the responses.
     * @param trainer the function that builds a model from training inputs and responses.
     * @param <T> the type of an input sample.
     * @param <M> the type of the model.
     * @return the validation results of all rounds, in bag order.
     */
    public static <T, M extends Regression<T>> RegressionValidations<M> of(Bag[] bags, T[] x, double[] y, BiFunction<T[], double[], M> trainer) {
        List<RegressionValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            // Keep the element type T so the per-round call can infer its generics.
            T[] trainx = MathEx.slice(x, bag.samples);
            double[] trainy = MathEx.slice(y, bag.samples);
            T[] testx = MathEx.slice(x, bag.oob);
            double[] testy = MathEx.slice(y, bag.oob);
            rounds.add(of(trainx, trainy, testx, testy, trainer));
        }
        return new RegressionValidations<>(rounds);
    }

    /**
     * Trains a model on {@code train} and validates it row by row on {@code test}.
     * The test responses are extracted with {@code formula.y(test)}. Training and
     * scoring times are measured in milliseconds.
     *
     * @param formula the model formula.
     * @param train the training data.
     * @param test the test data.
     * @param trainer the function that builds a model from a formula and training data.
     * @param <M> the type of the model.
     * @return the validation result.
     */
    public static <M extends DataFrameRegression> RegressionValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula, DataFrame, M> trainer) {
        double[] testy = formula.y(test).toDoubleArray();

        long start = System.nanoTime();
        M model = trainer.apply(formula, train);
        double fitTime = elapsedMillis(start);

        start = System.nanoTime();
        int n = test.nrow();
        double[] prediction = new double[n];
        for (int i = 0; i < n; i++) {
            prediction[i] = model.predict((Tuple) test.get(i));
        }
        double scoreTime = elapsedMillis(start);

        RegressionMetrics metrics = RegressionMetrics.of(fitTime, scoreTime, testy, prediction);
        return new RegressionValidation<>(model, testy, prediction, metrics);
    }

    /**
     * Runs one train/validate round per bag on a data frame. Each round trains
     * on the rows indexed by {@code bag.samples} and validates on the rows
     * indexed by {@code bag.oob}.
     *
     * @param bags the train/test index splits.
     * @param formula the model formula.
     * @param data the data.
     * @param trainer the function that builds a model from a formula and training data.
     * @param <M> the type of the model.
     * @return the validation results of all rounds, in bag order.
     */
    public static <M extends DataFrameRegression> RegressionValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        List<RegressionValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            DataFrame train = data.of(bag.samples);
            DataFrame test = data.of(bag.oob);
            rounds.add(of(formula, train, test, trainer));
        }
        return new RegressionValidations<>(rounds);
    }

    /** Returns the milliseconds elapsed since {@code start}, a {@link System#nanoTime()} reading. */
    private static double elapsedMillis(long start) {
        return (System.nanoTime() - start) / 1E6;
    }
}

