/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.CART;
import smile.base.cart.LeafNode;
import smile.base.cart.Loss;
import smile.base.cart.NominalSplit;
import smile.base.cart.OrdinalSplit;
import smile.base.cart.RegressionNode;
import smile.base.cart.Split;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;

/**
 * Regression tree built on top of {@link CART}.
 *
 * <p>Each leaf is a {@link RegressionNode} holding the output supplied by the
 * {@link Loss} together with the weighted mean and residual sum of squares of
 * the response inside the node. Split candidates are scored by the quantity
 * {@code tc * trueMean^2 + fc * falseMean^2 - n * mean^2}; only candidates with
 * a strictly positive score are returned.
 */
public class RegressionTree extends CART implements DataFrameRegression {
    private static final long serialVersionUID = 2L;

    /** Response values taken from {@code loss.response()}; transient, so not serialized. */
    private final transient double[] y;

    /** Loss that supplies the node output; transient, so not serialized. */
    private final transient Loss loss;

    /**
     * Returns the impurity stored in the given leaf.
     *
     * @param node the leaf, which must be a {@link RegressionNode}.
     * @return the value of {@code RegressionNode.impurity()}.
     */
    @Override
    protected double impurity(LeafNode node) {
        return ((RegressionNode) node).impurity();
    }

    /**
     * Creates a leaf for the given sample indices.
     *
     * @param nodeSamples indices of the samples that fall into the node.
     * @return a {@link RegressionNode} carrying the weighted size, the loss
     *         output, the weighted response mean and the weighted RSS.
     */
    @Override
    protected LeafNode newNode(int[] nodeSamples) {
        double out = loss.output(nodeSamples, samples);

        // The RSS is measured around the weighted mean of y. The loss output is
        // reused as that mean only when the loss identifies itself as
        // "LeastSquares"; for any other loss the mean is recomputed here.
        double mean = out;
        if (!loss.toString().equals("LeastSquares")) {
            int weight = 0;
            double weightedSum = 0.0;
            for (int i : nodeSamples) {
                weight += samples[i];
                weightedSum += y[i] * (double) samples[i];
            }
            mean = weightedSum / (double) weight;
        }

        int n = 0;
        double rss = 0.0;
        for (int i : nodeSamples) {
            n += samples[i];
            rss += (double) samples[i] * MathEx.pow2(y[i] - mean);
        }
        return new RegressionNode(n, out, mean, rss);
    }

    /**
     * Searches column {@code j} for the best split of a leaf.
     *
     * @param leaf the leaf to split, which must be a {@link RegressionNode}.
     * @param j the column index of the candidate feature.
     * @param impurity the impurity of the leaf (not used by this implementation).
     * @param lo the inclusive lower bound of the leaf's range in the index arrays.
     * @param hi the exclusive upper bound of the leaf's range in the index arrays.
     * @return the best split with a strictly positive score, or empty if none qualifies.
     */
    @Override
    protected Optional<Split> findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi) {
        RegressionNode node = (RegressionNode) leaf;
        ValueVector xj = x.column(j);

        // Weighted sum of the response over the samples of this leaf.
        double sum = 0.0;
        for (int i = lo; i < hi; i++) {
            int o = index[i];
            sum += y[o] * (double) samples[o];
        }
        double nodeMeanSquared = (double) node.size() * node.mean() * node.mean();

        Measure measure = schema.field(j).measure();
        Split split;
        if (measure instanceof NominalScale scale) {
            split = bestNominalSplit(node, j, xj, scale, sum, nodeMeanSquared, lo, hi);
        } else {
            split = bestOrdinalSplit(node, j, xj, sum, nodeMeanSquared, lo, hi);
        }
        return Optional.ofNullable(split);
    }

    /**
     * Finds the best "equals one level" split on a nominal column.
     *
     * @return the split, or {@code null} if no level yields a positive score.
     */
    private Split bestNominalSplit(RegressionNode node, int j, ValueVector xj, NominalScale scale,
                                   double sum, double nodeMeanSquared, int lo, int hi) {
        int m = scale.size();
        int[] trueCount = new int[m];
        double[] trueSum = new double[m];
        for (int i = lo; i < hi; i++) {
            int o = index[i];
            int level = xj.getInt(o);
            trueCount[level] += samples[o];
            trueSum[level] += y[o] * (double) samples[o];
        }

        int splitValue = -1;
        double splitScore = 0.0;
        int splitTrueCount = 0;
        int splitFalseCount = 0;
        for (int l : scale.values()) {
            int tc = trueCount[l];
            int fc = node.size() - tc;
            // Both children must hold at least nodeSize samples. This check
            // also runs before the divisions by tc and fc below.
            if (tc < nodeSize || fc < nodeSize) {
                continue;
            }

            double trueMean = trueSum[l] / (double) tc;
            double falseMean = (sum - trueSum[l]) / (double) fc;
            double gain = (double) tc * trueMean * trueMean + (double) fc * falseMean * falseMean - nodeMeanSquared;
            if (gain > splitScore) {
                splitValue = l;
                splitTrueCount = tc;
                splitFalseCount = fc;
                splitScore = gain;
            }
        }

        if (splitScore > 0.0) {
            int value = splitValue;
            return new NominalSplit(node, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, o -> xj.getInt(o) == value);
        }
        return null;
    }

    /**
     * Finds the best threshold split on a non-nominal column by scanning the
     * samples in the order given by {@code order[j]}.
     *
     * <p>The system property {@code smile.regression_tree.bins} (default 100)
     * thins out the candidates: when it is greater than 10, a threshold is
     * scored only after the running weighted sample count has crossed another
     * multiple of {@code y.length / bins}.
     *
     * @return the split, or {@code null} if no threshold yields a positive score.
     * @throws NumberFormatException if the system property is not an integer.
     */
    private Split bestOrdinalSplit(RegressionNode node, int j, ValueVector xj,
                                   double sum, double nodeMeanSquared, int lo, int hi) {
        int[] orderj = order[j];
        int bins = Integer.parseInt(System.getProperty("smile.regression_tree.bins", "100"));
        int step = bins > 10 ? Math.max(1, y.length / bins) : 1;

        // Weighted count of the samples that precede this leaf's range; only
        // needed when candidates are actually being thinned out.
        int k = 0;
        if (step > 1) {
            for (int i = 0; i < lo; i++) {
                k += samples[orderj[i]];
            }
        }
        int checkpoint = k / step;

        double splitValue = 0.0;
        double splitScore = 0.0;
        int splitTrueCount = 0;
        int splitFalseCount = 0;

        double trueSum = 0.0;
        double prevx = xj.getDouble(orderj[lo]);
        int tc = 0;
        for (int i = lo; i < hi; i++) {
            int o = orderj[i];
            double xij = xj.getDouble(o);

            // A threshold can only be placed between two distinct values.
            if (!MathEx.isZero(xij - prevx, 1.0E-7)) {
                int fc = node.size() - tc;
                if (tc >= nodeSize && fc >= nodeSize && k / step > checkpoint) {
                    checkpoint = k / step;
                    double trueMean = trueSum / (double) tc;
                    double falseMean = (sum - trueSum) / (double) fc;
                    double gain = (double) tc * trueMean * trueMean + (double) fc * falseMean * falseMean - nodeMeanSquared;
                    if (gain > splitScore) {
                        splitValue = (xij + prevx) / 2.0;
                        splitTrueCount = tc;
                        splitFalseCount = fc;
                        splitScore = gain;
                    }
                }
            }

            prevx = xij;
            trueSum += y[o] * (double) samples[o];
            tc += samples[o];
            k += samples[o];
        }

        if (splitScore > 0.0) {
            double value = splitValue;
            return new OrdinalSplit(node, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, o -> xj.getDouble(o) <= value);
        }
        return null;
    }

    /**
     * Grows a regression tree.
     *
     * <p>When {@code maxNodes} equals {@code Integer.MAX_VALUE} the root split
     * is applied without a queue. Otherwise pending splits are kept in a
     * priority queue and applied until {@code maxNodes} leaves exist or the
     * queue is empty.
     *
     * @param x the predictors.
     * @param loss the loss providing the response values and node outputs.
     * @param response the field of the response variable.
     * @param maxDepth forwarded to the {@link CART} constructor.
     * @param maxNodes the maximum number of leaves; also forwarded to {@link CART}.
     * @param nodeSize the minimum weighted sample count of a child node.
     * @param mtry forwarded to the {@link CART} constructor.
     * @param samples forwarded to the {@link CART} constructor; {@link #fit} passes {@code null}.
     * @param order forwarded to the {@link CART} constructor; {@link #fit} passes {@code null}.
     */
    public RegressionTree(DataFrame x, Loss loss, StructField response, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) {
        super(x, response, maxDepth, maxNodes, nodeSize, mtry, samples, order);
        this.loss = loss;
        this.y = loss.response();

        // Use the field (this.samples), not the constructor argument, which
        // may be null.
        int[] rootSamples = IntStream.range(0, x.size()).filter(i -> this.samples[i] > 0).toArray();
        LeafNode node = this.newNode(rootSamples);
        this.root = node;

        Optional<Split> rootSplit = this.findBestSplit(node, 0, this.index.length, new boolean[x.ncol()]);
        if (maxNodes == Integer.MAX_VALUE) {
            rootSplit.ifPresent(s -> this.split(s, null));
        } else {
            // The initial capacity is only a sizing hint because the queue
            // grows on demand. Clamp it so that maxNodes <= 0 (e.g. derived
            // from data smaller than nodeSize) or an int overflow of
            // 2 * maxNodes cannot make the PriorityQueue constructor throw.
            long wanted = 2L * maxNodes;
            int capacity = (int) Math.max(1L, Math.min(wanted, (long) Math.max(1, x.size())));
            PriorityQueue<Split> queue = new PriorityQueue<>(capacity, Split.comparator.reversed());
            rootSplit.ifPresent(queue::add);

            int leaves = 1;
            while (leaves < this.maxNodes && !queue.isEmpty()) {
                if (this.split(queue.poll(), queue)) {
                    ++leaves;
                }
            }
        }

        this.root = this.root.merge();
        this.clear();
    }

    /**
     * Fits a regression tree with the default {@link Options}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted tree.
     */
    public static RegressionTree fit(Formula formula, DataFrame data) {
        return RegressionTree.fit(formula, data, new Options());
    }

    /**
     * Fits a regression tree with a least squares loss on the response.
     *
     * <p>All predictor columns are candidates at every split. If
     * {@code options.maxNodes()} is not positive, the node limit becomes
     * {@code data.size() / options.nodeSize()}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param options the hyperparameters.
     * @return the fitted tree.
     */
    public static RegressionTree fit(Formula formula, DataFrame data, Options options) {
        Formula expanded = formula.expand(data.schema());
        DataFrame x = expanded.x(data);
        ValueVector y = expanded.y(data);

        int mtry = x.ncol();
        int maxNodes = options.maxNodes() > 0 ? options.maxNodes() : data.size() / options.nodeSize();

        RegressionTree tree = new RegressionTree(x, Loss.ls(y.toDoubleArray()), y.field(), options.maxDepth(), maxNodes, options.nodeSize(), mtry, null, null);
        tree.formula = expanded;
        return tree;
    }

    /**
     * Predicts the response for one instance.
     *
     * @param x the instance.
     * @return the output stored in the leaf the instance is routed to.
     */
    @Override
    public double predict(Tuple x) {
        RegressionNode leaf = (RegressionNode) this.root.predict(this.predictors(x));
        return leaf.output();
    }

    /**
     * Returns the formula associated with this tree.
     *
     * @return the formula; set by {@link #fit(Formula, DataFrame, Options)}.
     */
    @Override
    public Formula formula() {
        return this.formula;
    }

    /**
     * Returns the schema of the predictors.
     *
     * @return the schema.
     */
    @Override
    public StructType schema() {
        return this.schema;
    }

    /**
     * Regression tree hyperparameters.
     *
     * @param maxDepth the maximum tree depth; must be at least 2.
     * @param maxNodes the maximum number of leaves; a non-positive value lets
     *                 {@code fit} derive the limit from the data size.
     * @param nodeSize the minimum number of samples in a child node; must be at least 1.
     */
    public record Options(int maxDepth, int maxNodes, int nodeSize) {
        /**
         * Validates the hyperparameters.
         *
         * @throws IllegalArgumentException if {@code maxDepth < 2} or {@code nodeSize < 1}.
         */
        public Options {
            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }
        }

        /** Creates the default options: maxDepth 20, maxNodes 0, nodeSize 5. */
        public Options() {
            this(20, 0, 5);
        }

        /**
         * Stores the options under the {@code smile.cart.*} property keys.
         *
         * @return a new {@link Properties} holding the three values as strings.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.cart.max_depth", Integer.toString(this.maxDepth));
            props.setProperty("smile.cart.max_nodes", Integer.toString(this.maxNodes));
            props.setProperty("smile.cart.node_size", Integer.toString(this.nodeSize));
            return props;
        }

        /**
         * Reads the options from the {@code smile.cart.*} property keys,
         * using 20, 0 and 5 for keys that are absent.
         *
         * @param props the properties to read.
         * @return the options.
         * @throws NumberFormatException if a present value is not an integer.
         * @throws IllegalArgumentException if the parsed values fail validation.
         */
        public static Options of(Properties props) {
            int maxDepth = Integer.parseInt(props.getProperty("smile.cart.max_depth", "20"));
            int maxNodes = Integer.parseInt(props.getProperty("smile.cart.max_nodes", "0"));
            int nodeSize = Integer.parseInt(props.getProperty("smile.cart.node_size", "5"));
            return new Options(maxDepth, maxNodes, nodeSize);
        }
    }
}

