/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.SplitRule;
import smile.classification.ClassLabels;
import smile.classification.DataFrameClassifier;
import smile.classification.DecisionTree;
import smile.classification.SoftClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.Strings;

public class RandomForest
implements SoftClassifier<Tuple>,
DataFrameClassifier {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(RandomForest.class);
    private Formula formula;
    private List<Tree> trees;
    private int k = 2;
    private double error;
    private double[] importance;
    private IntSet labels;

    public RandomForest(Formula formula, int k, List<Tree> trees, double error, double[] importance) {
        this(formula, k, trees, error, importance, IntSet.of((int)k));
    }

    public RandomForest(Formula formula, int k, List<Tree> trees, double error, double[] importance, IntSet labels) {
        this.formula = formula;
        this.k = k;
        this.trees = trees;
        this.error = error;
        this.importance = importance;
        this.labels = labels;
    }

    public static RandomForest fit(Formula formula, DataFrame data) {
        return RandomForest.fit(formula, data, new Properties());
    }

    public static RandomForest fit(Formula formula, DataFrame data, Properties prop) {
        int ntrees = Integer.valueOf(prop.getProperty("smile.random.forest.trees", "500"));
        int mtry = Integer.valueOf(prop.getProperty("smile.random.forest.mtry", "0"));
        SplitRule rule = SplitRule.valueOf(prop.getProperty("smile.random.forest.split.rule", "GINI"));
        int maxDepth = Integer.valueOf(prop.getProperty("smile.random.forest.max.depth", "20"));
        int maxNodes = Integer.valueOf(prop.getProperty("smile.random.forest.max.nodes", String.valueOf(data.size() / 5)));
        int nodeSize = Integer.valueOf(prop.getProperty("smile.random.forest.node.size", "5"));
        double subsample = Double.valueOf(prop.getProperty("smile.random.forest.sample.rate", "1.0"));
        int[] classWeight = Strings.parseIntArray((String)prop.getProperty("smile.random.forest.class.weight"));
        return RandomForest.fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null);
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample) {
        return RandomForest.fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, null);
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight) {
        return RandomForest.fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null);
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, LongStream seeds) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
        }
        DataFrame x = formula.x(data);
        BaseVector y = formula.y(data);
        if (mtry > x.ncols()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        int mtryFinal = mtry > 0 ? mtry : (int)Math.sqrt(x.ncols());
        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int n = x.nrows();
        int[] weight = classWeight != null ? classWeight : Collections.nCopies(k, 1).stream().mapToInt(i -> i).toArray();
        int[][] order = CART.order(x);
        int[][] prediction = new int[n][k];
        long[] seedArray = (seeds != null ? seeds : LongStream.range(-ntrees, 0L)).sequential().distinct().limit(ntrees).toArray();
        if (seedArray.length != ntrees) {
            throw new IllegalArgumentException(String.format("seed stream has only %d distinct values, expected %d", seedArray.length, ntrees));
        }
        int[] count = new int[k];
        for (int i2 = 0; i2 < n; ++i2) {
            int n2 = y.getInt(i2);
            count[n2] = count[n2] + 1;
        }
        int[][] yi = new int[k][];
        for (int i3 = 0; i3 < k; ++i3) {
            yi[i3] = new int[count[i3]];
        }
        int[] idx = new int[k];
        int i4 = 0;
        while (i4 < n) {
            int j;
            int n3 = j = y.getInt(i4);
            int n4 = idx[n3];
            idx[n3] = n4 + 1;
            yi[j][n4] = i4++;
        }
        List<Tree> trees = Arrays.stream(seedArray).parallel().mapToObj(seed -> {
            int xj;
            int j;
            int i;
            if (seed > 1L) {
                MathEx.setSeed((long)seed);
            }
            int[] samples = new int[n];
            if (subsample == 1.0) {
                for (i = 0; i < k; ++i) {
                    int ni = count[i];
                    int size = ni / weight[i];
                    int[] yj = yi[i];
                    for (j = 0; j < size; ++j) {
                        xj = MathEx.randomInt((int)ni);
                        int n2 = yj[xj];
                        samples[n2] = samples[n2] + 1;
                    }
                }
            } else {
                for (i = 0; i < k; ++i) {
                    int size = (int)Math.round(subsample * (double)count[i] / (double)weight[i]);
                    int[] yj = yi[i];
                    int[] permutation = MathEx.permutate((int)count[i]);
                    for (j = 0; j < size; ++j) {
                        xj = permutation[j];
                        int n3 = yj[xj];
                        samples[n3] = samples[n3] + 1;
                    }
                }
            }
            DecisionTree tree = new DecisionTree(x, codec.y, codec.field, k, rule, maxDepth, maxNodes, nodeSize, mtryFinal, samples, order);
            int oob = 0;
            int correct = 0;
            for (int i2 = 0; i2 < n; ++i2) {
                if (samples[i2] != 0) continue;
                ++oob;
                int p = tree.predict((Tuple)x.get(i2));
                if (p == y.getInt(i2)) {
                    ++correct;
                }
                int[] nArray = prediction[i2];
                int n4 = p;
                nArray[n4] = nArray[n4] + 1;
            }
            double accuracy = 1.0;
            if (oob != 0) {
                accuracy = (double)correct / (double)oob;
                logger.info("Random forest tree OOB size: {}, accuracy: {}", (Object)oob, (Object)String.format("%.2f%%", 100.0 * accuracy));
            } else {
                logger.error("Random forest has a tree trained without OOB samples.");
            }
            return new Tree(tree, accuracy);
        }).collect(Collectors.toList());
        int err = 0;
        int m = 0;
        for (int i5 = 0; i5 < n; ++i5) {
            int pred = MathEx.whichMax((int[])prediction[i5]);
            if (prediction[i5][pred] <= 0) continue;
            ++m;
            if (pred == y.getInt(i5)) continue;
            ++err;
        }
        double error = m > 0 ? (double)err / (double)m : 0.0;
        return new RandomForest(formula, k, trees, error, RandomForest.importance(trees), codec.labels);
    }

    private static double[] importance(List<Tree> trees) {
        int p = trees.get((int)0).tree.importance().length;
        double[] importance = new double[p];
        for (Tree tree : trees) {
            double[] imp = tree.tree.importance();
            for (int i = 0; i < p; ++i) {
                int n = i;
                importance[n] = importance[n] + imp[i];
            }
        }
        return importance;
    }

    @Override
    public Formula formula() {
        return this.formula;
    }

    @Override
    public StructType schema() {
        return this.trees.get((int)0).tree.schema();
    }

    public double error() {
        return this.error;
    }

    public double[] importance() {
        return this.importance;
    }

    public int size() {
        return this.trees.size();
    }

    public DecisionTree[] trees() {
        return (DecisionTree[])this.trees.stream().map(t -> t.tree).toArray(DecisionTree[]::new);
    }

    public void trim(int ntrees) {
        if (ntrees > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        ArrayList<Tree> model = new ArrayList<Tree>(ntrees);
        for (int i = 0; i < ntrees; ++i) {
            model.add(this.trees.get(i));
        }
        this.trees = model;
    }

    @Override
    public int predict(Tuple x) {
        Tuple xt = this.formula.x(x);
        int[] y = new int[this.k];
        for (Tree tree : this.trees) {
            int n = tree.tree.predict(xt);
            y[n] = y[n] + 1;
        }
        return this.labels.valueOf(MathEx.whichMax((int[])y));
    }

    @Override
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        Tuple xt = this.formula.x(x);
        double[] prob = new double[this.k];
        Arrays.fill(posteriori, 0.0);
        for (Tree tree : this.trees) {
            tree.tree.predict(xt, prob);
            for (int i = 0; i < this.k; ++i) {
                int n = i;
                posteriori[n] = posteriori[n] + tree.weight * prob[i];
            }
        }
        MathEx.unitize1((double[])posteriori);
        return this.labels.valueOf(MathEx.whichMax((double[])posteriori));
    }

    public int vote(Tuple x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        Tuple xt = this.formula.x(x);
        Arrays.fill(posteriori, 0.0);
        for (Tree tree : this.trees) {
            int n = tree.tree.predict(xt);
            posteriori[n] = posteriori[n] + 1.0;
        }
        MathEx.unitize1((double[])posteriori);
        return this.labels.valueOf(MathEx.whichMax((double[])posteriori));
    }

    public int[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.size();
        int ntrees = this.trees.size();
        int[] p = new int[this.k];
        int[][] prediction = new int[ntrees][n];
        for (int j = 0; j < n; ++j) {
            Tuple xj = (Tuple)x.get(j);
            Arrays.fill(p, 0);
            for (int i = 0; i < ntrees; ++i) {
                int n2 = this.trees.get((int)i).tree.predict(xj);
                p[n2] = p[n2] + 1;
                prediction[i][j] = MathEx.whichMax((int[])p);
            }
        }
        return prediction;
    }

    public RandomForest prune(DataFrame test) {
        List<Tree> forest = ((Stream)this.trees.stream().parallel()).map(tree -> new Tree(tree.tree.prune(test, this.formula, this.labels), tree.weight)).collect(Collectors.toList());
        return new RandomForest(this.formula, this.k, forest, this.error, RandomForest.importance(forest), this.labels);
    }

    static class Tree
    implements Serializable {
        DecisionTree tree;
        double weight;

        Tree(DecisionTree tree, double weight) {
            this.tree = tree;
            this.weight = weight;
        }
    }
}

