/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Optional;
import java.util.Properties;
import java.util.function.LongSupplier;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.regression.RegressionTree;

public class RandomForest
implements Regression<Tuple>,
DataFrameRegression {
    private static final long serialVersionUID = 2L;
    private Formula formula;
    private RegressionTree[] trees;
    private double error;
    private double[] importance;

    public RandomForest(Formula formula, RegressionTree[] trees, double error, double[] importance) {
        this.formula = formula;
        this.trees = trees;
        this.error = error;
        this.importance = importance;
    }

    public static RandomForest fit(Formula formula, DataFrame data) {
        return RandomForest.fit(formula, data, new Properties());
    }

    public static RandomForest fit(Formula formula, DataFrame data, Properties prop) {
        int ntrees = Integer.valueOf(prop.getProperty("smile.random.forest.trees", "500"));
        int mtry = Integer.valueOf(prop.getProperty("smile.random.forest.mtry", "0"));
        int maxDepth = Integer.valueOf(prop.getProperty("smile.random.forest.max.depth", "20"));
        int maxNodes = Integer.valueOf(prop.getProperty("smile.random.forest.max.nodes", String.valueOf(data.size() / 5)));
        int nodeSize = Integer.valueOf(prop.getProperty("smile.random.forest.node.size", "5"));
        double subsample = Double.valueOf(prop.getProperty("smile.random.forest.sample.rate", "1.0"));
        return RandomForest.fit(formula, data, ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample);
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample) {
        return RandomForest.fit(formula, data, ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample, Optional.empty());
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, LongSupplier seedGenerator) {
        return RandomForest.fit(formula, data, ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample, Optional.of(LongStream.generate(seedGenerator)));
    }

    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, Optional<LongStream> seeds) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
        }
        DataFrame x = formula.x(data);
        BaseVector response = formula.y(data);
        StructField field = response.field();
        double[] y = response.toDoubleArray();
        if (mtry > x.ncols()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        int mtryFinal = mtry > 0 ? mtry : Math.max(x.ncols() / 3, 1);
        int n = x.nrows();
        double[] prediction = new double[n];
        int[] oob = new int[n];
        int[][] order = CART.order(x);
        long[] seedArray = seeds.orElse(LongStream.range(-ntrees, 0L)).sequential().distinct().limit(ntrees).toArray();
        if (seedArray.length != ntrees) {
            throw new IllegalArgumentException(String.format("seed stream has only %d distinct values, expected %d", seedArray.length, ntrees));
        }
        RegressionTree[] trees = (RegressionTree[])Arrays.stream(seedArray).parallel().mapToObj(seed -> {
            if (seed > 1L) {
                MathEx.setSeed((long)seed);
            }
            int[] samples = new int[n];
            if (subsample == 1.0) {
                for (int i2 = 0; i2 < n; ++i2) {
                    int n2 = MathEx.randomInt((int)n);
                    samples[n2] = samples[n2] + 1;
                }
            } else {
                int[] permutation = MathEx.permutate((int)n);
                int N = (int)Math.round((double)n * subsample);
                for (int i3 = 0; i3 < N; ++i3) {
                    samples[permutation[i3]] = 1;
                }
            }
            RegressionTree tree = new RegressionTree(x, Loss.ls(y), field, maxDepth, maxNodes, nodeSize, mtryFinal, samples, order);
            IntStream.range(0, n).filter(i -> samples[i] == 0).forEach(i -> {
                double pred = tree.predict((Tuple)x.get(i));
                int n = i;
                prediction[n] = prediction[n] + pred;
                int n2 = i;
                oob[n2] = oob[n2] + 1;
            });
            return tree;
        }).toArray(RegressionTree[]::new);
        int m = 0;
        double error = 0.0;
        for (int i = 0; i < n; ++i) {
            if (oob[i] <= 0) continue;
            ++m;
            double pred = prediction[i] / (double)oob[i];
            error += MathEx.sqr((double)(pred - y[i]));
        }
        if (m > 0) {
            error = Math.sqrt(error / (double)m);
        }
        double[] importance = RandomForest.calculateImportance(trees);
        return new RandomForest(formula, trees, error, importance);
    }

    public RandomForest merge(RandomForest other) {
        if (!this.formula.equals(other.formula)) {
            throw new IllegalArgumentException("RandomForest have different sizes of feature vectors");
        }
        RegressionTree[] forest = new RegressionTree[this.trees.length + other.trees.length];
        System.arraycopy(this.trees, 0, forest, 0, this.trees.length);
        System.arraycopy(other.trees, 0, forest, this.trees.length, other.trees.length);
        double mergedError = this.error * other.error / 2.0;
        double[] mergedImportance = (double[])this.importance.clone();
        for (int i = 0; i < this.importance.length; ++i) {
            int n = i;
            mergedImportance[n] = mergedImportance[n] + other.importance[i];
        }
        return new RandomForest(this.formula, forest, mergedError, mergedImportance);
    }

    private static double[] calculateImportance(RegressionTree[] trees) {
        double[] importance = new double[trees[0].importance().length];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; ++i) {
                int n = i;
                importance[n] = importance[n] + imp[i];
            }
        }
        return importance;
    }

    @Override
    public Formula formula() {
        return this.formula;
    }

    @Override
    public StructType schema() {
        return this.trees[0].schema();
    }

    public double error() {
        return this.error;
    }

    public double[] importance() {
        return this.importance;
    }

    public int size() {
        return this.trees.length;
    }

    public RegressionTree[] trees() {
        return this.trees;
    }

    public void trim(int ntrees) {
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        RegressionTree[] model = new RegressionTree[ntrees];
        System.arraycopy(this.trees, 0, model, 0, ntrees);
        this.trees = model;
    }

    @Override
    public double predict(Tuple x) {
        Tuple xt = this.formula.x(x);
        double y = 0.0;
        for (RegressionTree tree : this.trees) {
            y += tree.predict(xt);
        }
        return y / (double)this.trees.length;
    }

    public double[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.nrows();
        int ntrees = this.trees.length;
        double[][] prediction = new double[ntrees][n];
        for (int j = 0; j < n; ++j) {
            Tuple xj = (Tuple)x.get(j);
            double base = 0.0;
            for (int i = 0; i < ntrees; ++i) {
                prediction[i][j] = (base += this.trees[i].predict(xj)) / (double)(i + 1);
            }
        }
        return prediction;
    }
}

