/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.Properties;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.SplitRule;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.DataFrameClassifier;
import smile.classification.DecisionTree;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;
import smile.util.Strings;
import smile.validation.ClassificationMetrics;
import smile.validation.metric.Accuracy;
import smile.validation.metric.Error;

public class RandomForest
extends AbstractClassifier<Tuple>
implements DataFrameClassifier,
TreeSHAP {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(RandomForest.class);
    /** The model formula; used to extract the predictors from incoming tuples and data frames. */
    private final Formula formula;
    /** The trees of the forest, each paired with its own metrics and voting weight. */
    private final Model[] models;
    /** The number of classes; sizes the vote and posterior probability arrays. */
    private final int k;
    /** The metrics of the whole forest, as handed to the constructor. */
    private final ClassificationMetrics metrics;
    /** The per-variable importance, as handed to the constructor. */
    private final double[] importance;

    /**
     * Creates a forest whose class label set is {@code IntSet.of(k)}.
     *
     * @param formula the model formula.
     * @param k the number of classes.
     * @param models the trees of the forest.
     * @param metrics the metrics of the forest.
     * @param importance the variable importance.
     */
    public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance) {
        this(formula, k, models, metrics, importance, IntSet.of(k));
    }

    /**
     * Creates a forest with an explicit class label set.
     *
     * @param formula the model formula.
     * @param k the number of classes.
     * @param models the trees of the forest.
     * @param metrics the metrics of the forest.
     * @param importance the variable importance.
     * @param labels the class labels, forwarded to the superclass constructor.
     */
    public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance, IntSet labels) {
        super(labels);
        this.models = models;
        this.k = k;
        this.formula = formula;
        this.importance = importance;
        this.metrics = metrics;
    }

    /**
     * Fits a random forest of 500 trees with otherwise default options.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted forest.
     */
    public static RandomForest fit(Formula formula, DataFrame data) {
        Options defaults = new Options(500);
        return fit(formula, data, defaults);
    }

    /**
     * Fits a random forest for classification.
     *
     * <p>Each tree is trained on a stratified sample: every class is sampled
     * separately, and the per-class sample size is divided by the class weight.
     * When {@code options.subsample} is exactly 1.0 the rows are drawn with
     * replacement; otherwise a random permutation is used, so rows are drawn
     * without replacement. Rows that a tree did not sample are its out-of-bag
     * (OOB) rows and are used to score that tree and to accumulate the
     * forest-level OOB votes.
     *
     * <p>Trees are built in parallel. If {@code options.controller} is set, a
     * {@link TrainingStatus} is submitted to it after each tree is scored.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param options the hyperparameters.
     * @return the fitted forest.
     * @throws IllegalArgumentException if {@code options.mtry} exceeds the
     *         number of predictors.
     */
    public static RandomForest fit(Formula formula, DataFrame data, Options options) {
        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        ValueVector y = formula.y(data);
        int ncol = x.ncol();
        if (options.mtry > ncol) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + options.mtry);
        }

        // Non-positive settings fall back to data-dependent defaults.
        int mtry = options.mtry > 0 ? options.mtry : (int) Math.sqrt(ncol);
        int maxNodes = options.maxNodes > 0 ? options.maxNodes : Math.max(2, data.size() / 5);
        int ntrees = options.ntrees;
        double subsample = options.subsample;

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int n = x.size();

        // Without explicit class weights every class has weight 1.
        final int[] weight;
        if (options.classWeight != null) {
            weight = options.classWeight;
        } else {
            weight = new int[k];
            Arrays.fill(weight, 1);
        }

        int[][] order = CART.order(x);
        // prediction[i][c] counts the OOB votes for class c on row i.
        int[][] prediction = new int[n][k];

        // Number of rows per class.
        int[] count = new int[k];
        for (int i = 0; i < n; i++) {
            count[codec.y[i]]++;
        }

        // yi[c] lists the row indices that belong to class c.
        int[][] yi = new int[k][];
        for (int c = 0; c < k; c++) {
            yi[c] = new int[count[c]];
        }
        int[] idx = new int[k];
        for (int i = 0; i < n; i++) {
            int c = codec.y[i];
            yi[c][idx[c]++] = i;
        }

        Model[] models = IntStream.range(0, ntrees).parallel().mapToObj(t -> {
            if (options.seeds != null) {
                MathEx.setSeed(options.seeds[t]);
            }

            // samples[i] is the number of times row i was drawn for this tree.
            int[] samples = new int[n];
            if (subsample == 1.0) {
                // Draw with replacement within each class.
                for (int i = 0; i < k; i++) {
                    int ni = count[i];
                    int size = ni / weight[i];
                    int[] yj = yi[i];
                    for (int j = 0; j < size; j++) {
                        int xj = MathEx.randomInt(ni);
                        samples[yj[xj]]++;
                    }
                }
            } else {
                // Draw without replacement within each class.
                for (int i = 0; i < k; i++) {
                    int size = (int) Math.round(subsample * count[i] / weight[i]);
                    int[] yj = yi[i];
                    int[] permutation = MathEx.permutate(count[i]);
                    for (int j = 0; j < size; j++) {
                        int xj = permutation[j];
                        samples[yj[xj]]++;
                    }
                }
            }

            long start = System.nanoTime();
            DecisionTree tree = new DecisionTree(x, codec.y, y.field(), k, options.rule, options.maxDepth, maxNodes, options.nodeSize, mtry, samples, order);
            double fitTime = (double) (System.nanoTime() - start) / 1000000.0;

            start = System.nanoTime();
            int noob = 0;
            for (int i = 0; i < n; i++) {
                if (samples[i] == 0) {
                    noob++;
                }
            }

            int[] truth = new int[noob];
            int[] oob = new int[noob];
            double[][] posteriori = new double[noob][k];
            int j = 0;
            for (int i = 0; i < n; i++) {
                if (samples[i] != 0) {
                    continue;
                }
                truth[j] = codec.y[i];
                int predicted = tree.predict(x.get(i), posteriori[j]);
                oob[j] = predicted;
                // Several trees may be scoring the same row concurrently, and
                // an unguarded increment on the shared counter can lose votes.
                int[] votes = prediction[i];
                synchronized (votes) {
                    votes[predicted]++;
                }
                j++;
            }
            double scoreTime = (double) (System.nanoTime() - start) / 1000000.0;

            // Binary metrics are used when the OOB rows contain exactly two classes.
            ClassificationMetrics metrics;
            int oobk = MathEx.unique(truth).length;
            if (oobk == 2) {
                double[] probability = Arrays.stream(posteriori).mapToDouble(p -> p[1]).toArray();
                metrics = ClassificationMetrics.binary(fitTime, scoreTime, truth, oob, probability);
            } else {
                metrics = ClassificationMetrics.of(fitTime, scoreTime, truth, oob);
            }

            logger.info("Tree {}: OOB = {}, accuracy = {}%", t + 1, noob, String.format("%.2f", 100.0 * metrics.accuracy()));
            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t + 1, metrics));
            }
            return new Model(tree, metrics);
        }).toArray(Model[]::new);

        // The forest-level times are the sums of the per-tree times.
        double fitTime = 0.0;
        double scoreTime = 0.0;
        for (Model model : models) {
            fitTime += model.metrics.fitTime();
            scoreTime += model.metrics.scoreTime();
        }

        // Majority OOB vote per training row.
        int[] vote = new int[n];
        for (int i = 0; i < n; i++) {
            vote[i] = MathEx.whichMax(prediction[i]);
        }

        ClassificationMetrics metrics = new ClassificationMetrics(fitTime, scoreTime, n, Error.of(codec.y, vote), Accuracy.of(codec.y, vote));
        return new RandomForest(formula, k, models, metrics, RandomForest.importance(models), codec.classes);
    }

    /**
     * Sums the variable importance of every tree.
     *
     * @param models the trees; the first one determines the number of variables.
     * @return the element-wise sum of the trees' importance arrays.
     */
    private static double[] importance(Model[] models) {
        int p = models[0].tree.importance().length;
        double[] total = new double[p];
        for (int m = 0; m < models.length; m++) {
            double[] contribution = models[m].tree.importance();
            for (int v = 0; v < p; v++) {
                total[v] += contribution[v];
            }
        }
        return total;
    }

    /** Returns the model formula. */
    @Override
    public Formula formula() {
        return formula;
    }

    /** Returns the schema reported by the first tree of the forest. */
    @Override
    public StructType schema() {
        DecisionTree first = models[0].tree;
        return first.schema();
    }

    /** Returns the metrics of the forest. */
    public ClassificationMetrics metrics() {
        return metrics;
    }

    /** Returns the variable importance array held by this forest (not a copy). */
    public double[] importance() {
        return importance;
    }

    /** Returns the number of trees in the forest. */
    public int size() {
        return models.length;
    }

    /** Returns the array of tree models held by this forest (not a copy). */
    public Model[] models() {
        return models;
    }

    /**
     * Returns the decision trees of the forest in a newly allocated array,
     * in the same order as {@code models}.
     */
    public DecisionTree[] trees() {
        DecisionTree[] trees = new DecisionTree[models.length];
        for (int i = 0; i < models.length; i++) {
            trees[i] = models[i].tree;
        }
        return trees;
    }

    /**
     * Returns a smaller forest that keeps only the {@code ntrees} trees with
     * the highest weight.
     *
     * <p>This forest is left untouched: the ranking is done on a copy of the
     * model array, so the tree order seen by other methods does not change.
     * The variable importance of the result is computed from the retained
     * trees only. The metrics are carried over from this forest unchanged.
     *
     * @param ntrees the number of trees to keep.
     * @return the trimmed forest.
     * @throws IllegalArgumentException if {@code ntrees} is not positive or
     *         exceeds the current number of trees.
     */
    public RandomForest trim(int ntrees) {
        if (ntrees > this.models.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        // Model.compareTo orders by descending weight, so the best trees come first.
        Model[] ranked = this.models.clone();
        Arrays.sort(ranked);
        Model[] kept = Arrays.copyOf(ranked, ntrees);
        return new RandomForest(this.formula, this.k, kept, this.metrics, RandomForest.importance(kept), this.classes);
    }

    /**
     * Combines this forest with another one built on the same formula.
     *
     * <p>The result holds the trees of this forest followed by those of
     * {@code other}. Fit and score times are added, the size is taken from
     * this forest, every other metric is the average of the two forests'
     * values, and the variable importance is the element-wise sum.
     *
     * @param other the forest to merge with.
     * @return the merged forest.
     * @throws IllegalArgumentException if the two formulas are not equal.
     */
    public RandomForest merge(RandomForest other) {
        if (!formula.equals(other.formula)) {
            throw new IllegalArgumentException("RandomForest have different model formula");
        }

        int mine = models.length;
        int theirs = other.models.length;
        Model[] forest = new Model[mine + theirs];
        System.arraycopy(models, 0, forest, 0, mine);
        System.arraycopy(other.models, 0, forest, mine, theirs);

        ClassificationMetrics a = metrics;
        ClassificationMetrics b = other.metrics;
        ClassificationMetrics mergedMetrics = new ClassificationMetrics(
                a.fitTime() + b.fitTime(),
                a.scoreTime() + b.scoreTime(),
                a.size(),
                (a.error() + b.error()) / 2,
                (a.accuracy() + b.accuracy()) / 2.0,
                (a.sensitivity() + b.sensitivity()) / 2.0,
                (a.specificity() + b.specificity()) / 2.0,
                (a.precision() + b.precision()) / 2.0,
                (a.f1() + b.f1()) / 2.0,
                (a.mcc() + b.mcc()) / 2.0,
                (a.auc() + b.auc()) / 2.0,
                (a.logloss() + b.logloss()) / 2.0,
                (a.crossEntropy() + b.crossEntropy()) / 2.0);

        double[] mergedImportance = new double[importance.length];
        for (int v = 0; v < mergedImportance.length; v++) {
            mergedImportance[v] = importance[v] + other.importance[v];
        }

        return new RandomForest(formula, k, forest, mergedMetrics, mergedImportance, classes);
    }

    /**
     * Predicts the class label by unweighted majority vote of the trees.
     *
     * @param x the instance to classify.
     * @return the label of the class with the most votes.
     */
    @Override
    public int predict(Tuple x) {
        Tuple xt = formula.x(x);
        int[] votes = new int[k];
        for (int t = 0; t < models.length; t++) {
            votes[models[t].tree.predict(xt)]++;
        }
        int winner = MathEx.whichMax(votes);
        return classes.valueOf(winner);
    }

    /**
     * Always returns {@code true}; this class implements
     * {@code predict(Tuple, double[])}, which fills in per-class probabilities.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label and fills in the class probabilities.
     *
     * <p>Each tree's probability estimate is scaled by that tree's weight and
     * accumulated; the total is then normalized with {@code MathEx.unitize1}.
     *
     * @param x the instance to classify.
     * @param posteriori output array of length {@code k}; overwritten.
     * @return the label of the class with the largest accumulated probability.
     * @throws IllegalArgumentException if {@code posteriori} does not have
     *         length {@code k}.
     */
    @Override
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);
        double[] prob = new double[k];
        Arrays.fill(posteriori, 0.0);
        for (int t = 0; t < models.length; t++) {
            Model model = models[t];
            model.tree.predict(xt, prob);
            for (int c = 0; c < k; c++) {
                posteriori[c] += model.weight * prob[c];
            }
        }

        MathEx.unitize1(posteriori);
        int winner = MathEx.whichMax(posteriori);
        return classes.valueOf(winner);
    }

    /**
     * Predicts the class label by unweighted voting and reports the vote shares.
     *
     * <p>Every tree casts one vote for its predicted class; the counts are
     * then normalized with {@code MathEx.unitize1}.
     *
     * @param x the instance to classify.
     * @param posteriori output array of length {@code k}; overwritten.
     * @return the label of the class with the most votes.
     * @throws IllegalArgumentException if {@code posteriori} does not have
     *         length {@code k}.
     */
    public int vote(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);
        Arrays.fill(posteriori, 0.0);
        for (int t = 0; t < models.length; t++) {
            posteriori[models[t].tree.predict(xt)] += 1.0;
        }

        MathEx.unitize1(posteriori);
        int winner = MathEx.whichMax(posteriori);
        return classes.valueOf(winner);
    }

    /**
     * Computes the running majority vote for every row as trees are added
     * one at a time.
     *
     * @param data the data to predict on.
     * @return a matrix in which entry {@code [i][j]} is the result of
     *         {@code MathEx.whichMax} over the votes of trees {@code 0..i}
     *         for row {@code j}.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = formula.x(data);
        int rows = x.size();
        int ntrees = models.length;
        int[][] prediction = new int[ntrees][rows];
        int[] tally = new int[k];

        for (int row = 0; row < rows; row++) {
            Tuple instance = x.get(row);
            // Restart the vote count for each row.
            Arrays.fill(tally, 0);
            for (int t = 0; t < ntrees; t++) {
                tally[models[t].tree.predict(instance)]++;
                prediction[t][row] = MathEx.whichMax(tally);
            }
        }

        return prediction;
    }

    /**
     * Returns a new forest in which every tree has been pruned against the
     * given data.
     *
     * <p>Trees are pruned in parallel. Each pruned tree keeps the metrics of
     * the tree it was derived from, the forest metrics are carried over
     * unchanged, and the variable importance is recomputed from the pruned
     * trees.
     *
     * @param test the data handed to each tree's {@code prune} method.
     * @return the forest of pruned trees.
     */
    public RandomForest prune(DataFrame test) {
        // Keep the stream typed as Stream<Model>; a raw Stream would turn the
        // lambda parameter into Object and model.tree would not resolve.
        Model[] forest = Arrays.stream(this.models)
                .parallel()
                .map(model -> new Model(model.tree.prune(test, this.formula, this.classes), model.metrics))
                .toArray(Model[]::new);
        return new RandomForest(this.formula, this.k, forest, this.metrics, RandomForest.importance(forest), this.classes);
    }

    /**
     * A single tree of the forest together with its metrics and weight.
     *
     * @param tree the decision tree.
     * @param metrics the metrics of the tree.
     * @param weight the factor applied to the tree's probability estimate in
     *        soft prediction; also the sort key.
     */
    public record Model(DecisionTree tree, ClassificationMetrics metrics, double weight)
            implements Serializable, Comparable<Model> {

        /**
         * Creates a model whose weight is the accuracy reported by its metrics.
         *
         * @param tree the decision tree.
         * @param metrics the metrics of the tree.
         */
        public Model(DecisionTree tree, ClassificationMetrics metrics) {
            this(tree, metrics, metrics.accuracy());
        }

        /** Orders models by descending weight. */
        @Override
        public int compareTo(Model o) {
            double theirs = o.weight;
            double mine = weight;
            return Double.compare(theirs, mine);
        }
    }

    /**
     * Random forest hyperparameters.
     *
     * @param ntrees the number of trees; must be at least 1.
     * @param mtry the number of variables tried at each split; a non-positive
     *        value makes {@code fit} use the square root of the number of predictors.
     * @param rule the split rule handed to each decision tree.
     * @param maxDepth the maximal tree depth; must be at least 2.
     * @param maxNodes the maximal number of nodes; a non-positive value makes
     *        {@code fit} use {@code max(2, size / 5)}.
     * @param nodeSize the node size handed to each decision tree; must be at least 1.
     * @param subsample the sampling rate in (0, 1].
     * @param classWeight per-class divisors of the sample size; may be null.
     * @param seeds optional per-tree RNG seeds; if given, at least {@code ntrees} of them.
     * @param controller optional receiver of per-tree training status; may be null.
     */
    public record Options(int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, long[] seeds, IterativeAlgorithmController<TrainingStatus> controller) {
        private static final String KEY_TREES = "smile.random_forest.trees";
        private static final String KEY_MTRY = "smile.random_forest.mtry";
        private static final String KEY_SPLIT_RULE = "smile.random_forest.split_rule";
        private static final String KEY_MAX_DEPTH = "smile.random_forest.max_depth";
        private static final String KEY_MAX_NODES = "smile.random_forest.max_nodes";
        private static final String KEY_NODE_SIZE = "smile.random_forest.node_size";
        private static final String KEY_SAMPLING_RATE = "smile.random_forest.sampling_rate";
        private static final String KEY_CLASS_WEIGHT = "smile.random_forest.class_weight";

        /** Validates the hyperparameters. */
        public Options {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }
            if (subsample <= 0.0 || subsample > 1.0) {
                throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
            }
            if (seeds != null && seeds.length < ntrees) {
                throw new IllegalArgumentException("The number of RNG seeds is fewer than that of trees: " + seeds.length);
            }
        }

        /**
         * Creates options with the given number of trees and the defaults
         * mtry = 0, maxDepth = 20, maxNodes = 0, nodeSize = 5.
         *
         * @param ntrees the number of trees.
         */
        public Options(int ntrees) {
            this(ntrees, 0, 20, 0, 5);
        }

        /**
         * Creates options with the GINI split rule, a sampling rate of 1.0,
         * and no class weights, seeds or controller.
         *
         * @param ntrees the number of trees.
         * @param mtry the number of variables tried at each split.
         * @param maxDepth the maximal tree depth.
         * @param maxNodes the maximal number of nodes.
         * @param nodeSize the node size.
         */
        public Options(int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize) {
            this(ntrees, mtry, SplitRule.GINI, maxDepth, maxNodes, nodeSize, 1.0, null, null, null);
        }

        /**
         * Writes the options to a new property set. Class weights are written
         * only when present; seeds and the controller are not written.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty(KEY_TREES, String.valueOf(ntrees));
            props.setProperty(KEY_MTRY, String.valueOf(mtry));
            props.setProperty(KEY_SPLIT_RULE, rule.toString());
            props.setProperty(KEY_MAX_DEPTH, String.valueOf(maxDepth));
            props.setProperty(KEY_MAX_NODES, String.valueOf(maxNodes));
            props.setProperty(KEY_NODE_SIZE, String.valueOf(nodeSize));
            props.setProperty(KEY_SAMPLING_RATE, String.valueOf(subsample));
            if (classWeight != null) {
                props.setProperty(KEY_CLASS_WEIGHT, Arrays.toString(classWeight));
            }
            return props;
        }

        /**
         * Reads the options from a property set, using 500 trees, mtry 0,
         * GINI, depth 20, max nodes 0, node size 5 and sampling rate 1.0
         * for missing entries. Seeds and the controller are left null.
         *
         * @param props the properties.
         * @return the options.
         */
        public static Options of(Properties props) {
            String trees = props.getProperty(KEY_TREES, "500");
            int ntrees = Integer.parseInt(trees);
            String tried = props.getProperty(KEY_MTRY, "0");
            int mtry = Integer.parseInt(tried);
            String ruleName = props.getProperty(KEY_SPLIT_RULE, "GINI");
            SplitRule rule = SplitRule.valueOf(ruleName);
            String depth = props.getProperty(KEY_MAX_DEPTH, "20");
            int maxDepth = Integer.parseInt(depth);
            String nodes = props.getProperty(KEY_MAX_NODES, "0");
            int maxNodes = Integer.parseInt(nodes);
            String leaf = props.getProperty(KEY_NODE_SIZE, "5");
            int nodeSize = Integer.parseInt(leaf);
            String rate = props.getProperty(KEY_SAMPLING_RATE, "1.0");
            double subsample = Double.parseDouble(rate);
            int[] classWeight = Strings.parseIntArray(props.getProperty(KEY_CLASS_WEIGHT));
            return new Options(ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null, null);
        }
    }

    /**
     * The status submitted to the training controller after a tree is built.
     *
     * @param tree the one-based index of the tree.
     * @param metrics the metrics of that tree on its out-of-bag rows.
     */
    public record TrainingStatus(int tree, ClassificationMetrics metrics) {
    }
}

