/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.List;
import smile.math.MathEx;
import smile.validation.RegressionMetrics;
import smile.validation.RegressionValidation;

/**
 * Aggregated result of several regression validation rounds.
 *
 * <p>Instances built through {@link #of(List)} carry, for every field of
 * {@link RegressionMetrics}, the value of {@code MathEx.mean} and
 * {@code MathEx.stdev} taken over the supplied rounds. The canonical
 * constructor itself performs no computation or validation.
 *
 * @param rounds the individual validation rounds.
 * @param avg the per-field averages of the round metrics.
 * @param std the per-field standard deviations of the round metrics.
 * @param <M> the model type carried by each round.
 */
public record RegressionValidations<M>(List<RegressionValidation<M>> rounds, RegressionMetrics avg, RegressionMetrics std) implements Serializable {
    private static final long serialVersionUID = 3L;

    /**
     * Builds the summary statistics for a list of validation rounds.
     *
     * <p>Each metric of every round is copied into a per-metric array, then
     * {@code MathEx.mean} and {@code MathEx.stdev} are applied to each array.
     * The validation data size is an integer field, so its mean and standard
     * deviation are rounded to the nearest integer with {@link Math#round(double)}.
     *
     * <p>NOTE(review): the outcome for an empty list or a single round is
     * whatever {@code MathEx.mean}/{@code MathEx.stdev} do with such arrays;
     * confirm against their contract before relying on it.
     *
     * @param rounds the validation rounds to summarize.
     * @param <M> the model type carried by each round.
     * @return the rounds together with the mean and standard deviation of their metrics.
     */
    public static <M> RegressionValidations<M> of(List<RegressionValidation<M>> rounds) {
        final int count = rounds.size();

        // One slot per round for every metric.
        final double[] fitTimes = new double[count];
        final double[] scoreTimes = new double[count];
        final int[] sizes = new int[count];
        final double[] rssValues = new double[count];
        final double[] mseValues = new double[count];
        final double[] rmseValues = new double[count];
        final double[] madValues = new double[count];
        final double[] r2Values = new double[count];

        for (int round = 0; round < count; round++) {
            final RegressionMetrics m = rounds.get(round).metrics();
            fitTimes[round] = m.fitTime();
            scoreTimes[round] = m.scoreTime();
            sizes[round] = m.size();
            rssValues[round] = m.rss();
            mseValues[round] = m.mse();
            rmseValues[round] = m.rmse();
            madValues[round] = m.mad();
            r2Values[round] = m.r2();
        }

        // Means first, in the same field order as the RegressionMetrics constructor.
        final double fitTimeMean = MathEx.mean(fitTimes);
        final double scoreTimeMean = MathEx.mean(scoreTimes);
        final int sizeMean = (int) Math.round(MathEx.mean(sizes));
        final double rssMean = MathEx.mean(rssValues);
        final double mseMean = MathEx.mean(mseValues);
        final double rmseMean = MathEx.mean(rmseValues);
        final double madMean = MathEx.mean(madValues);
        final double r2Mean = MathEx.mean(r2Values);
        final RegressionMetrics mean = new RegressionMetrics(
                fitTimeMean, scoreTimeMean, sizeMean, rssMean, mseMean, rmseMean, madMean, r2Mean);

        // Then the standard deviations.
        final double fitTimeDev = MathEx.stdev(fitTimes);
        final double scoreTimeDev = MathEx.stdev(scoreTimes);
        final int sizeDev = (int) Math.round(MathEx.stdev(sizes));
        final double rssDev = MathEx.stdev(rssValues);
        final double mseDev = MathEx.stdev(mseValues);
        final double rmseDev = MathEx.stdev(rmseValues);
        final double madDev = MathEx.stdev(madValues);
        final double r2Dev = MathEx.stdev(r2Values);
        final RegressionMetrics deviation = new RegressionMetrics(
                fitTimeDev, scoreTimeDev, sizeDev, rssDev, mseDev, rmseDev, madDev, r2Dev);

        return new RegressionValidations<>(rounds, mean, deviation);
    }

    /**
     * Renders the summary as a brace-delimited, multi-line block with one
     * "average, plus-minus sign, standard deviation" entry per metric.
     * Times are labelled in milliseconds, and R2 is multiplied by 100 and
     * shown as a percentage.
     *
     * @return the formatted summary.
     */
    @Override
    public String toString() {
        return "{\n"
                + String.format("  fit time: %.3f ms \u00b1 %.3f,\n", avg.fitTime(), std.fitTime())
                + String.format("  score time: %.3f ms \u00b1 %.3f,\n", avg.scoreTime(), std.scoreTime())
                + String.format("  validation data size: %d \u00b1 %d,\n", avg.size(), std.size())
                + String.format("  RSS: %.4f \u00b1 %.4f,\n", avg.rss(), std.rss())
                + String.format("  MSE: %.4f \u00b1 %.4f,\n", avg.mse(), std.mse())
                + String.format("  RMSE: %.4f \u00b1 %.4f,\n", avg.rmse(), std.rmse())
                + String.format("  MAD: %.4f \u00b1 %.4f,\n", avg.mad(), std.mad())
                + String.format("  R2: %.2f%% \u00b1 %.2f\n}", 100.0 * avg.r2(), 100.0 * std.r2());
    }
}

