/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.Arrays;
import smile.data.DataFrame;
import smile.math.MathEx;
import smile.stat.Sampling;
import smile.util.Index;
import smile.util.Tuple2;

/**
 * A pair of index sets produced by splitting a data set.
 *
 * <p>{@code samples} holds the selected (training) row indices and {@code oob} holds the
 * remaining out-of-bag (held-out) row indices.
 *
 * <p>Note that both components are plain arrays. They are exposed as-is (no defensive copies),
 * and the record's generated {@code equals}/{@code hashCode} compare them by reference, not by
 * content.
 *
 * @param samples the indices of the selected (training) samples.
 * @param oob the indices of the out-of-bag (held-out) samples.
 */
public record Bag(int[] samples, int[] oob) implements Serializable {
    private static final long serialVersionUID = 3L;

    /**
     * Randomly splits the indices {@code 0 .. n-1} into a training part and a hold-out part.
     *
     * <p>A random permutation of {@code 0 .. n-1} is drawn. Its first
     * {@code round(n * (1 - holdout))} entries become the training indices, and the rest
     * become the hold-out indices.
     *
     * @param n the number of samples.
     * @param holdout the proportion of samples to hold out; must lie in the open interval (0, 1).
     * @return a bag whose {@code samples} are the training indices and whose {@code oob} are the
     *     held-out indices.
     * @throws IllegalArgumentException if {@code n} is negative, or {@code holdout} is not in
     *     (0, 1) (NaN included).
     */
    public static Bag split(int n, double holdout) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        checkHoldout(holdout);

        int[] index = MathEx.permutate(n);
        int trainSize = trainSizeOf(n, holdout);
        int[] train = Arrays.copyOf(index, trainSize);
        int[] test = Arrays.copyOfRange(index, trainSize, n);
        return new Bag(train, test);
    }

    /**
     * Randomly splits the rows of a data frame into a training frame and a hold-out frame.
     *
     * @param data the data frame to split.
     * @param holdout the proportion of rows to hold out; must lie in the open interval (0, 1).
     * @return the pair (training frame, hold-out frame).
     * @throws IllegalArgumentException if {@code holdout} is not in (0, 1) (NaN included).
     */
    public static Tuple2<DataFrame, DataFrame> split(DataFrame data, double holdout) {
        Bag bag = split(data.nrow(), holdout);
        return slice(data, bag);
    }

    /**
     * Performs a stratified random split of the indices {@code 0 .. category.length-1}.
     *
     * <p>The indices are grouped into strata by {@link Sampling#strata(int[])}, presumably one
     * stratum per distinct label (confirm against that helper). Each stratum is shuffled in place.
     * Its first {@code round(size * (1 - holdout))} indices go to the training part and the rest
     * go to the hold-out part. Finally, both parts are shuffled so the strata are interleaved.
     *
     * @param category the category label of each sample.
     * @param holdout the proportion of each stratum to hold out; must lie in (0, 1).
     * @return a bag whose {@code samples} are the training indices and whose {@code oob} are the
     *     held-out indices.
     * @throws IllegalArgumentException if {@code holdout} is not in (0, 1) (NaN included).
     */
    static Bag stratify(int[] category, double holdout) {
        checkHoldout(holdout);

        int[][] strata = Sampling.strata(category);
        int n = category.length;
        // Assumes the strata partition 0..n-1, so n slots are enough for either part.
        int[] train = new int[n];
        int[] test = new int[n];
        int p = 0; // number of training indices written so far
        int q = 0; // number of hold-out indices written so far

        for (int[] stratum : strata) {
            // Shuffle the stratum in place so the prefix/suffix split below is random.
            MathEx.permutate(stratum);
            int size = stratum.length;
            int trainSize = trainSizeOf(size, holdout);
            int testSize = size - trainSize;
            System.arraycopy(stratum, 0, train, p, trainSize);
            System.arraycopy(stratum, trainSize, test, q, testSize);
            p += trainSize;
            q += testSize;
        }

        train = Arrays.copyOf(train, p);
        test = Arrays.copyOf(test, q);
        // Interleave the strata; otherwise indices would stay grouped by category.
        MathEx.permutate(train);
        MathEx.permutate(test);
        return new Bag(train, test);
    }

    /**
     * Performs a stratified random split of the rows of a data frame.
     *
     * @param data the data frame to split.
     * @param category the name of the column whose values (converted via {@code toIntArray()})
     *     define the strata.
     * @param holdout the proportion of each stratum to hold out; must lie in (0, 1).
     * @return the pair (training frame, hold-out frame).
     * @throws IllegalArgumentException if {@code holdout} is not in (0, 1) (NaN included).
     */
    public static Tuple2<DataFrame, DataFrame> stratify(DataFrame data, String category, double holdout) {
        int[] label = data.column(category).toIntArray();
        Bag bag = stratify(label, holdout);
        return slice(data, bag);
    }

    /** Rejects hold-out proportions outside the open interval (0, 1), including NaN. */
    private static void checkHoldout(double holdout) {
        // Written as a negated range test because NaN fails every comparison;
        // the form `holdout <= 0 || holdout >= 1` would let NaN through.
        if (!(holdout > 0.0 && holdout < 1.0)) {
            throw new IllegalArgumentException("Invalid holdout proportion: " + holdout);
        }
    }

    /** Returns the number of training samples for a group of {@code size} samples. */
    private static int trainSizeOf(int size, double holdout) {
        return (int) Math.round(size * (1.0 - holdout));
    }

    /** Selects the training and out-of-bag rows of {@code data} described by {@code bag}. */
    private static Tuple2<DataFrame, DataFrame> slice(DataFrame data, Bag bag) {
        DataFrame train = data.get(Index.of(bag.samples()));
        DataFrame test = data.get(Index.of(bag.oob()));
        return new Tuple2<>(train, test);
    }
}

