package smile.feature;

import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.gap.BitString;
import smile.gap.Crossover;
import smile.gap.Fitness;
import smile.gap.GeneticAlgorithm;
import smile.gap.Selection;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.metric.ClassificationMetric;
import smile.validation.metric.RegressionMetric;

/**
 * Feature selection driven by a genetic algorithm.
 * <p>
 * Each candidate feature subset is encoded as a {@link BitString}: bit {@code j}
 * equal to 1 keeps feature {@code j}, any other value drops it. The static
 * {@code fitness} factories build a {@link Fitness} that trains a model on the
 * selected features only and scores it on a held-out test set. That fitness is
 * then passed to {@link #apply(int, int, int, Fitness)}.
 */
public class GAFE {
    /** Selection strategy handed to the genetic algorithm. */
    private final Selection selection;
    /** Elitism parameter handed to {@link GeneticAlgorithm}. */
    private final int elitism;
    /** Crossover strategy used by every {@link BitString} chromosome. */
    private final Crossover crossover;
    /** Crossover rate used by every {@link BitString} chromosome. */
    private final double crossoverRate;
    /** Mutation rate used by every {@link BitString} chromosome. */
    private final double mutationRate;

    /**
     * Creates a feature selector with default settings.
     * <p>
     * The defaults are {@code Selection.Tournament(3, 0.95)}, elitism 1,
     * two-point crossover, crossover rate 1.0 and mutation rate 0.01.
     */
    public GAFE() {
        this(Selection.Tournament(3, 0.95), 1, Crossover.TWO_POINT, 1.0, 0.01);
    }

    /**
     * Creates a feature selector.
     *
     * @param selection     the selection strategy of the genetic algorithm.
     * @param elitism       the elitism parameter passed to {@link GeneticAlgorithm}.
     * @param crossover     the crossover strategy of the bit-string chromosomes.
     * @param crossoverRate the crossover rate of the bit-string chromosomes.
     * @param mutationRate  the mutation rate of the bit-string chromosomes.
     */
    public GAFE(Selection selection, int elitism, Crossover crossover, double crossoverRate, double mutationRate) {
        this.selection = selection;
        this.elitism = elitism;
        this.crossover = crossover;
        this.crossoverRate = crossoverRate;
        this.mutationRate = mutationRate;
    }

    /**
     * Runs the genetic algorithm over feature subsets.
     *
     * @param size       the population size; must be positive.
     * @param generation the number of generations passed to {@code evolve}.
     * @param length     the chromosome length, i.e. the number of candidate features.
     * @param fitness    the fitness of a feature subset, e.g. one built by a
     *                   {@code fitness} factory of this class.
     * @return the population array given to the genetic algorithm, returned after
     *         {@code evolve(generation)} has run. Presumably the algorithm updates
     *         it in place; confirm against {@link GeneticAlgorithm}.
     * @throws IllegalArgumentException if {@code size <= 0}.
     */
    public BitString[] apply(int size, int generation, int length, Fitness<BitString> fitness) {
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid population size: " + size);
        }

        BitString[] seeds = new BitString[size];
        for (int i = 0; i < size; i++) {
            seeds[i] = new BitString(length, fitness, crossover, crossoverRate, mutationRate);
        }

        GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection, elitism);
        ga.evolve(generation);
        return seeds;
    }

    /**
     * Returns the positions of the bits equal to 1, in ascending order.
     * Returns null if the bits sum to zero (no feature selected).
     */
    private static int[] indexOfOnes(byte[] bits) {
        int p = MathEx.sum(bits);
        if (p == 0) {
            return null;
        }

        int[] index = new int[p];
        for (int i = 0, j = 0; i < bits.length; i++) {
            if (bits[i] == 1) {
                index[j++] = i;
            }
        }
        return index;
    }

    /** Returns a new matrix with the given columns of every row of {@code x}. */
    private static double[][] select(double[][] x, int[] features) {
        int n = x.length;
        int p = features.length;
        double[][] result = new double[n][p];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                result[i][j] = x[i][features[j]];
            }
        }
        return result;
    }

    /**
     * Returns the fitness of a feature subset for a classification model
     * on matrix data.
     *
     * @param x       the training samples.
     * @param y       the training labels.
     * @param testx   the test samples.
     * @param testy   the test labels.
     * @param measure the metric used to score the test predictions.
     * @param trainer trains a classifier on the selected columns of {@code x}.
     * @return the fitness: the metric score on the test set, or 0.0 when the
     *         subset is empty.
     */
    public static Fitness<BitString> fitness(double[][] x, int[] y, double[][] testx, int[] testy,
                                             ClassificationMetric measure,
                                             BiFunction<double[][], int[], Classifier<double[]>> trainer) {
        return chromosome -> {
            int[] features = indexOfOnes(chromosome.bits());
            if (features == null) {
                return 0.0;
            }

            Classifier<double[]> model = trainer.apply(select(x, features), y);
            int[] prediction = model.predict(select(testx, features));
            return measure.score(testy, prediction);
        };
    }

    /**
     * Returns the fitness of a feature subset for a regression model
     * on matrix data.
     * <p>
     * The metric is negated so that a smaller error gives a larger fitness.
     * This assumes the metric is lower-is-better (e.g. RMSE) and that the
     * genetic algorithm maximizes fitness.
     *
     * @param x       the training samples.
     * @param y       the training responses.
     * @param testx   the test samples.
     * @param testy   the test responses.
     * @param measure the metric used to score the test predictions.
     * @param trainer trains a regression model on the selected columns of {@code x}.
     * @return the fitness: the negated metric score on the test set, or
     *         negative infinity when the subset is empty.
     */
    public static Fitness<BitString> fitness(double[][] x, double[] y, double[][] testx, double[] testy,
                                             RegressionMetric measure,
                                             BiFunction<double[][], double[], Regression<double[]>> trainer) {
        return chromosome -> {
            int[] features = indexOfOnes(chromosome.bits());
            if (features == null) {
                return Double.NEGATIVE_INFINITY;
            }

            Regression<double[]> model = trainer.apply(select(x, features), y);
            double[] prediction = model.predict(select(testx, features));
            return -measure.score(testy, prediction);
        };
    }

    /**
     * Maps the selected bits to column names.
     * <p>
     * {@code names} includes the response column {@code y}, while the bits index
     * only the remaining columns in their original order. The offset
     * {@code skip} therefore becomes 1 once the response column has been passed.
     * Returns null if the bits sum to zero (no feature selected).
     */
    private static String[] selectedFeatures(byte[] bits, String[] names, String y) {
        int p = MathEx.sum(bits);
        if (p == 0) {
            return null;
        }

        int skip = 0;
        String[] features = new String[p];
        for (int i = 0, j = 0; i < bits.length; i++) {
            if (names[i].equals(y)) {
                skip++;
            }
            if (bits[i] == 1) {
                features[j++] = names[i + skip];
            }
        }
        return features;
    }

    /**
     * Returns the fitness of a feature subset for a classification model
     * on data frames.
     *
     * @param y       the name of the response column.
     * @param train   the training data, including the response column.
     * @param test    the test data, including the response column.
     * @param measure the metric used to score the test predictions.
     * @param trainer trains a classifier from a formula over the selected columns.
     * @return the fitness: the metric score on the test set, or 0.0 when the
     *         subset is empty.
     */
    public static Fitness<BitString> fitness(String y, DataFrame train, DataFrame test,
                                             ClassificationMetric measure,
                                             BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        String[] names = train.names();
        int[] testy = test.column(y).toIntArray();
        return chromosome -> {
            String[] features = selectedFeatures(chromosome.bits(), names, y);
            if (features == null) {
                return 0.0;
            }

            DataFrameClassifier model = trainer.apply(Formula.of(y, features), train);
            return measure.score(testy, model.predict(test));
        };
    }

    /**
     * Returns the fitness of a feature subset for a regression model
     * on data frames.
     * <p>
     * The metric is negated so that a smaller error gives a larger fitness.
     * This assumes the metric is lower-is-better and that the genetic algorithm
     * maximizes fitness.
     *
     * @param y       the name of the response column.
     * @param train   the training data, including the response column.
     * @param test    the test data, including the response column.
     * @param measure the metric used to score the test predictions.
     * @param trainer trains a regression model from a formula over the selected columns.
     * @return the fitness: the negated metric score on the test set, or
     *         negative infinity when the subset is empty.
     */
    public static Fitness<BitString> fitness(String y, DataFrame train, DataFrame test,
                                             RegressionMetric measure,
                                             BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        String[] names = train.names();
        double[] testy = test.column(y).toDoubleArray();
        return chromosome -> {
            String[] features = selectedFeatures(chromosome.bits(), names, y);
            if (features == null) {
                return Double.NEGATIVE_INFINITY;
            }

            DataFrameRegression model = trainer.apply(Formula.of(y, features), train);
            return -measure.score(testy, model.predict(test));
        };
    }
}
