/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.CART;
import smile.base.cart.LeafNode;
import smile.base.cart.Loss;
import smile.base.cart.NominalSplit;
import smile.base.cart.OrdinalSplit;
import smile.base.cart.RegressionNode;
import smile.base.cart.Split;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;

/**
 * Regression tree grown with the {@link CART} procedure.
 * <p>
 * The value predicted by each leaf is computed by the {@link Loss} supplied at
 * construction. Candidate splits are scored by
 * {@code tc * trueMean^2 + fc * falseMean^2 - n * mean^2} over sample-weighted
 * response sums, which is the reduction of the weighted residual sum of squares
 * around the node mean. Numeric columns are split by a threshold
 * ({@code x <= t}); nominal columns are split on a single level ({@code x == level}).
 */
public class RegressionTree extends CART implements Regression<Tuple>, DataFrameRegression {
    private static final long serialVersionUID = 2L;

    /** Training response obtained from {@link Loss#response()}; transient, so it is not serialized. */
    private transient double[] y;

    /** Loss that computes the output of each leaf; transient, so it is not serialized. */
    private transient Loss loss;

    /**
     * Returns the impurity stored in the node.
     *
     * @param node a leaf created by {@link #newNode(int[])}, hence a {@link RegressionNode}
     * @return {@link RegressionNode#impurity()} of the node
     */
    @Override
    protected double impurity(LeafNode node) {
        return ((RegressionNode) node).impurity();
    }

    /**
     * Creates a leaf for the given rows.
     * <p>
     * The leaf output comes from the loss. The leaf also records the total sample weight,
     * the sample-weighted mean of the response, and the weighted residual sum of squares
     * around that mean.
     *
     * @param nodeSamples row indices belonging to the node
     * @return a new {@link RegressionNode}
     */
    @Override
    protected LeafNode newNode(int[] nodeSamples) {
        double out = loss.output(nodeSamples, samples);

        // For least squares the loss output is reused as the mean (presumably
        // LeastSquares.output already is the weighted mean — confirm in Loss).
        // Every other loss gets the weighted mean computed explicitly.
        double mean = out;
        if (!"LeastSquares".equals(loss.toString())) {
            int weight = 0;
            double weightedSum = 0.0;
            for (int i : nodeSamples) {
                weight += samples[i];
                weightedSum += y[i] * samples[i];
            }
            mean = weightedSum / weight;
        }

        int n = 0;
        double rss = 0.0;
        for (int i : nodeSamples) {
            n += samples[i];
            rss += samples[i] * MathEx.sqr(y[i] - mean);
        }
        return new RegressionNode(n, out, mean, rss);
    }

    /**
     * Finds the best split of a leaf on column {@code j}.
     *
     * @param leaf the leaf to split (a {@link RegressionNode})
     * @param j the column index
     * @param impurity the impurity of the leaf (unused by this implementation)
     * @param lo the first position (inclusive) of the leaf's rows in {@code index}/{@code order[j]}
     * @param hi the last position (exclusive)
     * @return the best split, or empty if no split has a positive gain
     */
    @Override
    protected Optional<Split> findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi) {
        if (lo >= hi) {
            return Optional.empty();
        }

        RegressionNode node = (RegressionNode) leaf;
        BaseVector xj = x.column(j);
        // Sample-weighted sum of the response over the node's rows. DoubleStream.sum is kept
        // on purpose: it uses compensated summation, which a plain loop would not reproduce.
        double sum = IntStream.range(lo, hi).map(i -> index[i]).mapToDouble(i -> y[i] * samples[i]).sum();
        double nodeMeanSquared = (double) node.size() * node.mean() * node.mean();

        Measure measure = schema.field(j).measure;
        Split best = measure instanceof NominalScale
                ? findNominalSplit(leaf, node, xj, (NominalScale) measure, j, lo, hi, sum, nodeMeanSquared)
                : findOrdinalSplit(leaf, node, xj, j, lo, hi, sum, nodeMeanSquared);
        return Optional.ofNullable(best);
    }

    /** Scores every "x == level" split of a nominal column and returns the best one, or null. */
    private Split findNominalSplit(LeafNode leaf, RegressionNode node, BaseVector xj, NominalScale scale,
                                   int j, int lo, int hi, double sum, double nodeMeanSquared) {
        // Per-level sample weight and weighted response sum (assumes levels are 0..size()-1).
        int m = scale.size();
        int[] trueCount = new int[m];
        double[] trueSum = new double[m];
        for (int i = lo; i < hi; i++) {
            int o = index[i];
            int level = xj.getInt(o);
            trueCount[level] += samples[o];
            trueSum[level] += y[o] * samples[o];
        }

        // At least one sample per side, so the means below never divide by zero.
        int minSize = Math.max(nodeSize, 1);
        int splitValue = -1;
        int splitTrueCount = 0;
        int splitFalseCount = 0;
        double splitScore = 0.0;
        for (int level : scale.values()) {
            int tc = trueCount[level];
            int fc = node.size() - tc;
            if (tc < minSize || fc < minSize) {
                continue;
            }
            double score = rssReduction(tc, trueSum[level], fc, sum - trueSum[level], nodeMeanSquared);
            if (score > splitScore) {
                splitValue = level;
                splitTrueCount = tc;
                splitFalseCount = fc;
                splitScore = score;
            }
        }

        if (splitScore <= 0.0) {
            return null;
        }
        final int value = splitValue;
        return new NominalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount,
                o -> xj.getInt(o) == value);
    }

    /** Scores every threshold between distinct consecutive values of an ordered column; returns the best, or null. */
    private Split findOrdinalSplit(LeafNode leaf, RegressionNode node, BaseVector xj, int j, int lo, int hi,
                                   double sum, double nodeMeanSquared) {
        int[] orderj = order[j];
        int minSize = Math.max(nodeSize, 1);

        double splitValue = 0.0;
        int splitTrueCount = 0;
        int splitFalseCount = 0;
        double splitScore = 0.0;

        // tc/trueSum accumulate the rows strictly before position i (the "x <= t" side).
        int tc = 0;
        double trueSum = 0.0;
        double prevx = xj.getDouble(orderj[lo]);
        for (int i = lo; i < hi; i++) {
            int o = orderj[i];
            double xij = xj.getDouble(o);
            // A cut is only possible where the value changes; equal values must stay together.
            if (xij != prevx) {
                int fc = node.size() - tc;
                if (tc >= minSize && fc >= minSize) {
                    double score = rssReduction(tc, trueSum, fc, sum - trueSum, nodeMeanSquared);
                    if (score > splitScore) {
                        splitValue = splitPoint(prevx, xij);
                        splitTrueCount = tc;
                        splitFalseCount = fc;
                        splitScore = score;
                    }
                }
            }
            prevx = xij;
            trueSum += y[o] * samples[o];
            tc += samples[o];
        }

        if (splitScore <= 0.0) {
            return null;
        }
        final double value = splitValue;
        return new OrdinalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount,
                o -> xj.getDouble(o) <= value);
    }

    /**
     * Split score {@code tc*trueMean^2 + fc*falseMean^2 - nodeMeanSquared}.
     * <p>
     * When {@code nodeMeanSquared} is {@code n * mean^2} of the parent with {@code n = tc + fc},
     * this equals the decrease of the weighted residual sum of squares.
     */
    private static double rssReduction(int tc, double trueSum, int fc, double falseSum, double nodeMeanSquared) {
        double trueMean = trueSum / tc;
        double falseMean = falseSum / fc;
        return tc * trueMean * trueMean + fc * falseMean * falseMean - nodeMeanSquared;
    }

    /**
     * Returns a threshold {@code t} with {@code lower <= t < upper}, at their midpoint whenever it is
     * representable.
     * <p>
     * The predicate {@code x <= t} therefore sends {@code lower} to the true branch and
     * {@code upper} to the false branch. Assumes {@code lower < upper}, i.e. the column order is
     * ascending.
     */
    private static double splitPoint(double lower, double upper) {
        // Halve before adding: (lower + upper) / 2 overflows to infinity for large magnitudes.
        // Halving a normal double is exact, so otherwise this equals the rounded midpoint.
        double mid = lower / 2.0 + upper / 2.0;
        // With adjacent doubles the rounded midpoint can land exactly on upper; fall back to
        // lower so the predicate still separates the two sides as counted.
        return (mid >= upper || mid < lower) ? lower : mid;
    }

    /**
     * Grows a regression tree.
     * <p>
     * The root holds every row with a positive sample count. If {@code maxNodes} is
     * {@link Integer#MAX_VALUE}, splits are applied without a leaf budget. Otherwise candidate
     * splits are taken from a priority queue ordered by {@code Split.comparator.reversed()} until
     * {@code maxNodes} leaves exist or no candidate remains. Finally sister leaves are merged
     * via {@code merge()} and the training state is cleared.
     *
     * @param x the predictors
     * @param loss the loss; supplies the response and the leaf outputs
     * @param response the response field
     * @param maxDepth forwarded to {@link CART}
     * @param maxNodes forwarded to {@link CART}; also the leaf budget described above
     * @param nodeSize forwarded to {@link CART}; minimum sample weight on each side of a split
     * @param mtry forwarded to {@link CART}
     * @param samples per-row sample counts, or null (as passed by {@code fit}); presumably CART supplies defaults
     * @param order per-column row orderings, or null (as passed by {@code fit})
     */
    public RegressionTree(DataFrame x, Loss loss, StructField response, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) {
        super(x, response, maxDepth, maxNodes, nodeSize, mtry, samples, order);
        this.loss = loss;
        this.y = loss.response();

        LeafNode node = newNode(IntStream.range(0, x.size()).filter(i -> this.samples[i] > 0).toArray());
        this.root = node;

        Optional<Split> rootSplit = findBestSplit(node, 0, index.length, new boolean[x.ncols()]);
        if (maxNodes == Integer.MAX_VALUE) {
            rootSplit.ifPresent(s -> split(s, null));
        } else {
            // The queue never holds more candidates than there are leaves. The capacity is
            // only a sizing hint, so clamp it: 2 * maxNodes could overflow, and a capacity
            // below 1 makes the PriorityQueue constructor throw.
            int capacity = Math.max(1, Math.min(maxNodes, x.size()));
            PriorityQueue<Split> queue = new PriorityQueue<>(capacity, Split.comparator.reversed());
            rootSplit.ifPresent(queue::add);
            int leaves = 1;
            while (leaves < this.maxNodes && !queue.isEmpty()) {
                if (split(queue.poll(), queue)) {
                    ++leaves;
                }
            }
        }

        this.root = this.root.merge();
        clear();
    }

    /**
     * Fits a regression tree with default hyper-parameters.
     *
     * @param formula the model formula
     * @param data the training data
     * @return the fitted tree
     */
    public static RegressionTree fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a regression tree with hyper-parameters read from properties.
     * <p>
     * The properties read are:
     * <ul>
     *   <li>{@code smile.cart.max.depth}, default 20;</li>
     *   <li>{@code smile.cart.max.nodes}, default {@code data.size() / 5};</li>
     *   <li>{@code smile.cart.node.size}, default 5.</li>
     * </ul>
     *
     * @param formula the model formula
     * @param data the training data
     * @param prop the hyper-parameters
     * @return the fitted tree
     * @throws NumberFormatException if a property is not an integer
     */
    public static RegressionTree fit(Formula formula, DataFrame data, Properties prop) {
        int maxDepth = Integer.parseInt(prop.getProperty("smile.cart.max.depth", "20"));
        int maxNodes = Integer.parseInt(prop.getProperty("smile.cart.max.nodes", String.valueOf(data.size() / 5)));
        int nodeSize = Integer.parseInt(prop.getProperty("smile.cart.node.size", "5"));
        return fit(formula, data, maxDepth, maxNodes, nodeSize);
    }

    /**
     * Fits a least-squares regression tree.
     *
     * @param formula the model formula; splits {@code data} into predictors and response
     * @param data the training data
     * @param maxDepth forwarded to the constructor
     * @param maxNodes forwarded to the constructor
     * @param nodeSize forwarded to the constructor
     * @return the fitted tree, remembering {@code formula}
     */
    public static RegressionTree fit(Formula formula, DataFrame data, int maxDepth, int maxNodes, int nodeSize) {
        DataFrame x = formula.x(data);
        BaseVector y = formula.y(data);
        RegressionTree tree = new RegressionTree(x, Loss.ls(y.toDoubleArray()), y.field(), maxDepth, maxNodes, nodeSize, -1, null, null);
        tree.formula = formula;
        return tree;
    }

    /**
     * Predicts the response of a tuple.
     *
     * @param x the tuple
     * @return the output of the leaf that {@code x} reaches
     */
    @Override
    public double predict(Tuple x) {
        RegressionNode leaf = (RegressionNode) root.predict(predictors(x));
        return leaf.output();
    }

    /** Returns the formula set by {@code fit}, or null if the tree was built directly. */
    @Override
    public Formula formula() {
        return formula;
    }

    /** Returns the schema of the predictors. */
    @Override
    public StructType schema() {
        return schema;
    }
}

