/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.selection;

import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.gap.BitString;
import smile.gap.Chromosome;
import smile.gap.Crossover;
import smile.gap.Fitness;
import smile.gap.GeneticAlgorithm;
import smile.gap.Selection;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.metric.ClassificationMetric;
import smile.validation.metric.RegressionMetric;

/**
 * Genetic-algorithm-based feature selection.
 *
 * <p>Candidate feature subsets are encoded as {@link BitString} chromosomes:
 * a bit equal to {@code 1} marks the corresponding feature as selected. The
 * static {@code fitness} factories build {@link Fitness} functions that train
 * a model on the selected features and score it on a held-out test set.
 */
public class GAFE {
    /** Selection strategy handed to the genetic algorithm. */
    private final Selection selection;
    /** Elitism setting handed to the genetic algorithm. */
    private final int elitism;
    /** Crossover strategy used by every seed chromosome. */
    private final Crossover crossover;
    /** Crossover rate used by every seed chromosome. */
    private final double crossoverRate;
    /** Mutation rate used by every seed chromosome. */
    private final double mutationRate;

    /**
     * Creates an instance with the default settings: tournament selection
     * with arguments {@code (3, 0.95)}, elitism {@code 1}, two-point crossover,
     * crossover rate {@code 1.0} and mutation rate {@code 0.01}.
     */
    public GAFE() {
        this(Selection.Tournament(3, 0.95), 1, Crossover.TWO_POINT, 1.0, 0.01);
    }

    /**
     * Creates an instance with explicit genetic algorithm settings.
     *
     * @param selection     the selection strategy.
     * @param elitism       the elitism setting passed to {@link GeneticAlgorithm}.
     * @param crossover     the crossover strategy of the chromosomes.
     * @param crossoverRate the crossover rate of the chromosomes.
     * @param mutationRate  the mutation rate of the chromosomes.
     */
    public GAFE(Selection selection, int elitism, Crossover crossover, double crossoverRate, double mutationRate) {
        this.selection = selection;
        this.elitism = elitism;
        this.crossover = crossover;
        this.crossoverRate = crossoverRate;
        this.mutationRate = mutationRate;
    }

    /**
     * Runs the genetic algorithm over a freshly created population of bit strings.
     *
     * @param size       the population size; must be positive.
     * @param generation the argument passed to {@link GeneticAlgorithm#evolve}.
     * @param length     the length of each bit string (one bit per candidate feature).
     * @param fitness    the fitness function attached to every chromosome.
     * @return the population array that was handed to the genetic algorithm.
     *         NOTE(review): presumably the algorithm evolves this array in
     *         place -- confirm against {@link GeneticAlgorithm}.
     * @throws IllegalArgumentException if {@code size <= 0}.
     */
    public BitString[] apply(int size, int generation, int length, Fitness<BitString> fitness) {
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid population size: " + size);
        }

        BitString[] seeds = new BitString[size];
        for (int i = 0; i < size; i++) {
            seeds[i] = new BitString(length, fitness, crossover, crossoverRate, mutationRate);
        }

        GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection, elitism);
        ga.evolve(generation);
        return seeds;
    }

    /**
     * Returns the positions of the bits equal to {@code 1}.
     *
     * @param bits the chromosome bits.
     * @return the indices of the selected features in ascending order, or
     *         {@code null} if the bits sum to zero (no feature selected).
     */
    private static int[] indexOf(byte[] bits) {
        int p = MathEx.sum(bits);
        if (p == 0) {
            return null;
        }

        int[] index = new int[p];
        int ii = 0;
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == 1) {
                index[ii++] = i;
            }
        }
        return index;
    }

    /**
     * Copies the given columns of a row-major matrix into a new matrix.
     *
     * @param x        the data, one row per sample.
     * @param features the column indices to keep, in output order.
     * @return a new {@code x.length}-by-{@code features.length} matrix.
     */
    private static double[][] select(double[][] x, int[] features) {
        int p = features.length;
        int n = x.length;
        double[][] xx = new double[n][p];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                xx[i][j] = x[i][features[j]];
            }
        }
        return xx;
    }

    /**
     * Builds a fitness function for classification on array data.
     *
     * <p>The returned function trains a classifier on the selected columns
     * of {@code x} and returns the metric score of its predictions on the
     * same columns of {@code testx}. A chromosome that selects no feature
     * scores {@code 0.0}.
     *
     * @param x       the training samples.
     * @param y       the training labels.
     * @param testx   the test samples.
     * @param testy   the test labels.
     * @param metric  the classification metric used as fitness.
     * @param trainer trains a classifier from samples and labels.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(double[][] x, int[] y, double[][] testx, int[] testy, ClassificationMetric metric, BiFunction<double[][], int[], Classifier<double[]>> trainer) {
        return chromosome -> {
            byte[] bits = chromosome.bits();
            int[] features = indexOf(bits);
            if (features == null) {
                return 0.0;
            }

            double[][] xx = select(x, features);
            double[][] testxx = select(testx, features);
            Classifier<double[]> model = trainer.apply(xx, y);
            return metric.score(testy, model.predict(testxx));
        };
    }

    /**
     * Builds a fitness function for regression on array data.
     *
     * <p>The returned function trains a regression model on the selected
     * columns of {@code x} and returns the <em>negated</em> metric score of
     * its predictions on the same columns of {@code testx}, so a smaller
     * metric value yields a larger fitness. A chromosome that selects no
     * feature scores {@link Double#NEGATIVE_INFINITY}.
     *
     * @param x       the training samples.
     * @param y       the training response values.
     * @param testx   the test samples.
     * @param testy   the test response values.
     * @param metric  the regression metric whose negation is the fitness.
     * @param trainer trains a regression model from samples and responses.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(double[][] x, double[] y, double[][] testx, double[] testy, RegressionMetric metric, BiFunction<double[][], double[], Regression<double[]>> trainer) {
        return chromosome -> {
            byte[] bits = chromosome.bits();
            int[] features = indexOf(bits);
            if (features == null) {
                return Double.NEGATIVE_INFINITY;
            }

            double[][] xx = select(x, features);
            double[][] testxx = select(testx, features);
            Regression<double[]> model = trainer.apply(xx, y);
            return -metric.score(testy, model.predict(testxx));
        };
    }

    /**
     * Maps the bits equal to {@code 1} to column names, skipping the
     * response column.
     *
     * <p>Bit {@code i} refers to {@code names[i]} until a name equal to
     * {@code y} has been seen at the current position; from then on the
     * lookup is shifted by one so that the response column itself is never
     * returned. NOTE(review): this assumes {@code y} occurs at most once in
     * {@code names} and that {@code bits} has one entry per non-response
     * column -- confirm against callers.
     *
     * @param bits  the chromosome bits.
     * @param names the column names of the training data frame.
     * @param y     the name of the response column.
     * @return the names of the selected features, or {@code null} if the
     *         bits sum to zero (no feature selected).
     */
    private static String[] selectedFeatures(byte[] bits, String[] names, String y) {
        int p = MathEx.sum(bits);
        if (p == 0) {
            return null;
        }

        int offset = 0;
        String[] features = new String[p];
        int ii = 0;
        for (int i = 0; i < bits.length; i++) {
            // Once the response column is reached, later bits map one column to the right.
            if (names[i].equals(y)) {
                offset++;
            }
            if (bits[i] == 1) {
                features[ii++] = names[i + offset];
            }
        }
        return features;
    }

    /**
     * Builds a fitness function for classification on data frames.
     *
     * <p>The returned function builds a formula from {@code y} and the
     * selected column names, trains a classifier on {@code train}, and
     * returns the metric score of its predictions on {@code test}. A
     * chromosome that selects no feature scores {@code 0.0}.
     *
     * @param y       the name of the response column.
     * @param train   the training data.
     * @param test    the test data; its {@code y} column is read as integers.
     * @param metric  the classification metric used as fitness.
     * @param trainer trains a classifier from a formula and a data frame.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(String y, DataFrame train, DataFrame test, ClassificationMetric metric, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        String[] names = train.names();
        int[] testy = test.column(y).toIntArray();
        return chromosome -> {
            byte[] bits = chromosome.bits();
            String[] features = selectedFeatures(bits, names, y);
            if (features == null) {
                return 0.0;
            }

            Formula formula = Formula.of(y, features);
            DataFrameClassifier model = trainer.apply(formula, train);
            return metric.score(testy, model.predict(test));
        };
    }

    /**
     * Builds a fitness function for regression on data frames.
     *
     * <p>The returned function builds a formula from {@code y} and the
     * selected column names, trains a regression model on {@code train},
     * and returns the <em>negated</em> metric score of its predictions on
     * {@code test}, so a smaller metric value yields a larger fitness. A
     * chromosome that selects no feature scores
     * {@link Double#NEGATIVE_INFINITY}.
     *
     * @param y       the name of the response column.
     * @param train   the training data.
     * @param test    the test data; its {@code y} column is read as doubles.
     * @param metric  the regression metric whose negation is the fitness.
     * @param trainer trains a regression model from a formula and a data frame.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(String y, DataFrame train, DataFrame test, RegressionMetric metric, BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        String[] names = train.names();
        double[] testy = test.column(y).toDoubleArray();
        return chromosome -> {
            byte[] bits = chromosome.bits();
            String[] features = selectedFeatures(bits, names, y);
            if (features == null) {
                return Double.NEGATIVE_INFINITY;
            }

            Formula formula = Formula.of(y, features);
            DataFrameRegression model = trainer.apply(formula, train);
            return -metric.score(testy, model.predict(test));
        };
    }
}

