/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.sort.QuickSort;
import smile.util.IntSet;
import smile.validation.Bag;
import smile.validation.ClassificationValidation;
import smile.validation.ClassificationValidations;
import smile.validation.RegressionValidation;
import smile.validation.RegressionValidations;

/**
 * K-fold cross validation utilities. Each method partitions sample indices
 * into k folds and returns one {@link Bag} per fold, built from the indices
 * of the remaining folds (training split) and the indices of that fold
 * (test split).
 */
public interface CrossValidation {
    /**
     * Creates k-fold cross validation bags over {@code n} samples whose
     * indices are shuffled before being split into folds.
     *
     * @param n the number of samples.
     * @param k the number of folds.
     * @return k bags, one per fold.
     * @throws IllegalArgumentException if {@code n < 0}, {@code k <= 0} or {@code k > n}.
     */
    static Bag[] of(int n, int k) {
        return of(n, k, true);
    }

    /**
     * Creates k-fold cross validation bags over {@code n} samples. Folds are
     * contiguous slices of the (optionally shuffled) index array; every fold
     * has {@code n / k} samples except the last, which also takes the
     * {@code n % k} leftovers.
     *
     * @param n the number of samples.
     * @param k the number of folds.
     * @param shuffle whether to shuffle the sample indices before splitting.
     * @return k bags, one per fold.
     * @throws IllegalArgumentException if {@code n < 0}, {@code k <= 0} or {@code k > n}.
     */
    static Bag[] of(int n, int k, boolean shuffle) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        // k == 0 must be rejected here; otherwise it fails later with an ArithmeticException in n / k.
        if (k <= 0 || k > n) {
            throw new IllegalArgumentException("Invalid number of CV rounds: " + k);
        }

        int[] index = IntStream.range(0, n).toArray();
        if (shuffle) {
            MathEx.permutate(index);
        }

        int chunk = n / k;
        Bag[] bags = new Bag[k];
        for (int i = 0; i < k; i++) {
            int start = chunk * i;
            int end = (i == k - 1) ? n : chunk * (i + 1);

            // Test split is index[start, end); training split is everything before and after it.
            int[] test = Arrays.copyOfRange(index, start, end);
            int[] train = new int[n - (end - start)];
            System.arraycopy(index, 0, train, 0, start);
            System.arraycopy(index, end, train, start, n - end);
            bags[i] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Creates stratified k-fold cross validation bags: each class label is
     * split into k folds separately, so every fold keeps roughly the class
     * proportions of the whole data set. A warning is logged when the least
     * populated class has fewer than k members.
     *
     * @param category the class label of each sample.
     * @param k the number of folds.
     * @return k bags, one per fold; the indices in each bag are shuffled.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code category} is empty.
     */
    static Bag[] of(int[] category, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }
        if (category.length == 0) {
            throw new IllegalArgumentException("Empty category array");
        }

        int[] unique = MathEx.unique(category);
        int m = unique.length;
        Arrays.sort(unique);

        // Re-encode labels to 0..m-1 unless the sorted distinct labels already are exactly that.
        int n = category.length;
        int[] y = category;
        if (unique[0] != 0 || unique[m - 1] != m - 1) {
            IntSet encoder = new IntSet(unique);
            y = new int[n];
            for (int i = 0; i < n; i++) {
                y[i] = encoder.indexOf(category[i]);
            }
        }

        // Number of samples in each stratum.
        int[] ni = new int[m];
        for (int label : y) {
            ni[label]++;
        }

        int min = MathEx.min(ni);
        if (min < k) {
            Logger logger = LoggerFactory.getLogger(CrossValidation.class);
            logger.warn("The least populated class has only {} members, which is less than k={}.", min, k);
        }

        // Sample indices of each stratum, in their original order.
        int[][] strata = new int[m][];
        for (int j = 0; j < m; j++) {
            strata[j] = new int[ni[j]];
        }
        int[] pos = new int[m];
        for (int i = 0; i < n; i++) {
            int j = y[i];
            strata[j][pos[j]++] = i;
        }

        // Per-stratum fold width; at least 1 so strata smaller than k still contribute test samples.
        int[] chunk = new int[m];
        for (int j = 0; j < m; j++) {
            chunk[j] = Math.max(1, ni[j] / k);
        }

        Bag[] bags = new Bag[k];
        for (int i = 0; i < k; i++) {
            int p = 0;
            int q = 0;
            int[] train = new int[n];
            int[] test = new int[n];

            for (int j = 0; j < m; j++) {
                int size = ni[j];
                int start = chunk[j] * i;
                // The last fold takes whatever remains of each stratum.
                int end = (i == k - 1) ? size : chunk[j] * (i + 1);
                int[] stratum = strata[j];
                for (int l = 0; l < size; l++) {
                    if (l >= start && l < end) {
                        test[q++] = stratum[l];
                    } else {
                        train[p++] = stratum[l];
                    }
                }
            }

            train = Arrays.copyOf(train, p);
            test = Arrays.copyOf(test, q);
            // Indices are grouped by stratum at this point; shuffle so classes are interleaved.
            MathEx.permutate(train);
            MathEx.permutate(test);
            bags[i] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Creates group k-fold cross validation bags: all samples that share a
     * group id land in the same fold, so no group appears in both the
     * training and the test split of any bag. Groups are assigned greedily,
     * largest first, to the fold with the smallest current size.
     *
     * @param group the group id of each sample.
     * @param k the number of folds.
     * @return k bags, one per fold; indices are in ascending sample order.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k} exceeds
     *         the number of distinct groups.
     */
    static Bag[] nonoverlap(int[] group, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }

        int[] unique = MathEx.unique(group);
        int m = unique.length;
        // Also rejects an empty group array, since then m == 0 < k.
        if (k > m) {
            throw new IllegalArgumentException(
                    "k-fold must be not greater than the number of groups: k = " + k + ", groups = " + m);
        }

        Arrays.sort(unique);
        int n = group.length;
        int[] y = group;
        if (unique[0] != 0 || unique[m - 1] != m - 1) {
            IntSet encoder = new IntSet(unique);
            y = new int[n];
            for (int i = 0; i < n; i++) {
                y[i] = encoder.indexOf(group[i]);
            }
        }

        // Number of samples in each group.
        int[] ni = new int[m];
        for (int g : y) {
            ni[g]++;
        }

        // NOTE(review): relies on QuickSort.sort sorting ni ascending in place and returning the
        // original positions, so after this call ni[i] is the size of group index[i].
        int[] index = QuickSort.sort(ni);

        // Walk groups from largest to smallest, dropping each into the currently smallest fold.
        int[] foldSize = new int[k];
        int[] group2Fold = new int[m];
        for (int i = m - 1; i >= 0; i--) {
            int smallestFold = MathEx.whichMin(foldSize);
            foldSize[smallestFold] += ni[i];
            group2Fold[index[i]] = smallestFold;
        }

        Bag[] bags = new Bag[k];
        for (int fold = 0; fold < k; fold++) {
            int[] train = new int[n - foldSize[fold]];
            int[] test = new int[foldSize[fold]];
            int trainIndex = 0;
            int testIndex = 0;
            for (int j = 0; j < n; j++) {
                if (group2Fold[y[j]] == fold) {
                    test[testIndex++] = j;
                } else {
                    train[trainIndex++] = j;
                }
            }
            bags[fold] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Runs k-fold cross validation of a classifier over shuffled folds from
     * {@link #of(int, int)}, delegating to {@code ClassificationValidation.of}.
     *
     * @param k the number of folds.
     * @param x the samples.
     * @param y the sample labels.
     * @param trainer the function that builds a model from training samples and labels.
     * @param <T> the data type of samples.
     * @param <M> the model type.
     * @return the validation results produced by {@code ClassificationValidation.of}.
     */
    static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        return ClassificationValidation.of(of(x.length, k), x, y, trainer);
    }

    /**
     * Runs k-fold cross validation of a data frame classifier over shuffled
     * folds from {@link #of(int, int)}, delegating to {@code ClassificationValidation.of}.
     *
     * @param k the number of folds.
     * @param formula the model specification.
     * @param data the training data.
     * @param trainer the function that builds a model from a formula and a data frame.
     * @param <M> the model type.
     * @return the validation results produced by {@code ClassificationValidation.of}.
     */
    static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        return ClassificationValidation.of(of(data.size(), k), formula, data, trainer);
    }

    /**
     * Runs k-fold cross validation of a regression model over shuffled folds
     * from {@link #of(int, int)}, delegating to {@code RegressionValidation.of}.
     *
     * @param k the number of folds.
     * @param x the samples.
     * @param y the response variable.
     * @param trainer the function that builds a model from training samples and responses.
     * @param <T> the data type of samples.
     * @param <M> the model type.
     * @return the validation results produced by {@code RegressionValidation.of}.
     */
    static <T, M extends Regression<T>> RegressionValidations<M> regression(int k, T[] x, double[] y, BiFunction<T[], double[], M> trainer) {
        return RegressionValidation.of(of(x.length, k), x, y, trainer);
    }

    /**
     * Runs k-fold cross validation of a data frame regression model over
     * shuffled folds from {@link #of(int, int)}, delegating to {@code RegressionValidation.of}.
     *
     * @param k the number of folds.
     * @param formula the model specification.
     * @param data the training data.
     * @param trainer the function that builds a model from a formula and a data frame.
     * @param <M> the model type.
     * @return the validation results produced by {@code RegressionValidation.of}.
     */
    static <M extends DataFrameRegression> RegressionValidations<M> regression(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        return RegressionValidation.of(of(data.size(), k), formula, data, trainer);
    }
}

