/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Properties;
import smile.base.cart.CART;
import smile.base.cart.DecisionNode;
import smile.base.cart.InternalNode;
import smile.base.cart.LeafNode;
import smile.base.cart.Node;
import smile.base.cart.NominalSplit;
import smile.base.cart.OrdinalSplit;
import smile.base.cart.Split;
import smile.base.cart.SplitRule;
import smile.classification.ClassLabels;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * Classification tree built on top of the {@link CART} base class.
 *
 * <p>Instances are created either by training (the public constructor or the
 * static {@code fit} methods) or by {@code prune}, which wraps an already
 * built node structure.
 */
public class DecisionTree
extends CART
implements Classifier<Tuple>,
DataFrameClassifier {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 2L;
    /** Rule used to score node impurity when evaluating candidate splits. */
    private final SplitRule rule;
    /** Number of classes; sizes every per-class count array in this class. */
    private final int k;
    /**
     * Translation between class indices and original labels. May be
     * {@code null}, in which case {@code predict} returns the raw class index.
     */
    private IntSet classes;
    /**
     * Class index of each training sample. Only assigned by the training
     * constructor and marked transient, so it is not serialized.
     */
    private transient int[] y;

    /**
     * Evaluates the impurity of a leaf under this tree's split rule.
     *
     * @param node the leaf to evaluate; it is cast to {@link DecisionNode}.
     * @return the value of {@code DecisionNode.impurity(rule)} for that leaf.
     */
    @Override
    protected double impurity(LeafNode node) {
        DecisionNode decisionNode = (DecisionNode) node;
        return decisionNode.impurity(rule);
    }

    /**
     * Builds a leaf whose class histogram covers the given samples.
     *
     * @param nodeSamples indices of the training samples that fall in the node.
     * @return a {@link DecisionNode} holding, for each class, the sum of
     *         {@code samples[i]} over the listed samples of that class.
     */
    @Override
    protected LeafNode newNode(int[] nodeSamples) {
        int[] classCounts = new int[k];
        for (int pos = 0; pos < nodeSamples.length; pos++) {
            int sample = nodeSamples[pos];
            classCounts[y[sample]] += samples[sample];
        }
        return new DecisionNode(classCounts);
    }

    /**
     * Searches column {@code j} for the split of {@code leaf} with the largest
     * impurity reduction.
     *
     * <p>Columns whose measure is a {@link NominalScale} are tested with
     * "value == v" splits, one candidate per scale value. All other columns are
     * tested with "value &lt;= threshold" splits. A candidate is only kept when
     * both sides hold at least {@code nodeSize} (weighted) samples and its gain
     * is strictly greater than the best gain seen so far, which starts at zero.
     *
     * @param leaf the leaf to split; it is cast to {@link DecisionNode}.
     * @param j the index of the column to examine.
     * @param impurity the impurity of {@code leaf}, used as the baseline of the gain.
     * @param lo the inclusive start of the sample range of the node.
     * @param hi the exclusive end of the sample range of the node.
     * @return the best split found, or an empty Optional if no candidate has a
     *         positive gain.
     */
    @Override
    protected Optional<Split> findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi) {
        DecisionNode node = (DecisionNode)leaf;
        ValueVector xj = this.x.column(j);
        // Scratch buffer for the class counts of the false branch of a candidate.
        int[] falseCount = new int[this.k];
        Split split = null;
        // Best gain so far. Starting at zero means only strictly positive gains win.
        double splitScore = 0.0;
        int splitTrueCount = 0;
        int splitFalseCount = 0;
        Measure measure = this.schema.field(j).measure();
        if (measure instanceof NominalScale) {
            NominalScale scale = (NominalScale)measure;
            int splitValue = -1;
            int m = scale.size();
            // trueCount[v][c]: weighted number of node samples with value v and class c.
            // NOTE(review): indexing by xj.getInt(o) assumes the nominal values
            // lie in [0, m) -- confirm against NominalScale.
            int[][] trueCount = new int[m][this.k];
            for (int i = lo; i < hi; ++i) {
                int o2 = this.index[i];
                int[] nArray = trueCount[xj.getInt(o2)];
                int n = this.y[o2];
                nArray[n] = nArray[n] + this.samples[o2];
            }
            for (int l : scale.values()) {
                int tc = (int)MathEx.sum((int[])trueCount[l]);
                int fc = node.size() - tc;
                // Skip the value if either side would be smaller than nodeSize.
                if (tc < this.nodeSize || fc < this.nodeSize) continue;
                for (int q = 0; q < this.k; ++q) {
                    falseCount[q] = node.count()[q] - trueCount[l][q];
                }
                // Gain = parent impurity minus the size-weighted impurity of both sides.
                double gain = impurity - (double)tc / (double)node.size() * DecisionNode.impurity(this.rule, tc, trueCount[l]) - (double)fc / (double)node.size() * DecisionNode.impurity(this.rule, fc, falseCount);
                if (!(gain > splitScore)) continue;
                splitValue = l;
                splitTrueCount = tc;
                splitFalseCount = fc;
                splitScore = gain;
            }
            if (splitScore > 0.0) {
                // Copy into a local that is never reassigned so the lambda can capture it.
                int value = splitValue;
                split = new NominalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, o -> xj.getInt(o) == value);
            }
        } else {
            double splitValue = 0.0;
            // Running class counts of the samples visited before the current one.
            int[] trueCount = new int[this.k];
            // Sample order for column j; presumably sorted by the column value -- TODO confirm in CART.
            int[] orderj = this.order[j];
            int first = orderj[lo];
            double prevx = xj.getDouble(first);
            int prevy = this.y[first];
            for (int i = lo; i < hi; ++i) {
                // Both stay zero unless a threshold is worth testing at this position,
                // which makes the size check below fail whenever nodeSize >= 1.
                int tc = 0;
                int fc = 0;
                int o3 = orderj[i];
                int yi = this.y[o3];
                double xij = xj.getDouble(o3);
                // Only test a threshold where the class label changes and the value
                // differs from the previous one by more than the 1.0E-7 tolerance.
                if (yi != prevy && !MathEx.isZero((double)(xij - prevx), (double)1.0E-7)) {
                    tc = (int)MathEx.sum((int[])trueCount);
                    fc = node.size() - tc;
                }
                if (tc >= this.nodeSize && fc >= this.nodeSize) {
                    for (int l = 0; l < this.k; ++l) {
                        falseCount[l] = node.count()[l] - trueCount[l];
                    }
                    double gain = impurity - (double)tc / (double)node.size() * DecisionNode.impurity(this.rule, tc, trueCount) - (double)fc / (double)node.size() * DecisionNode.impurity(this.rule, fc, falseCount);
                    if (gain > splitScore) {
                        // The threshold is the midpoint between the previous and current value.
                        splitValue = (xij + prevx) / 2.0;
                        splitTrueCount = tc;
                        splitFalseCount = fc;
                        splitScore = gain;
                    }
                }
                // The current sample joins the running counts only after the test above,
                // so trueCount never includes the sample at position i while it is tested.
                prevx = xij;
                int n = prevy = yi;
                trueCount[n] = trueCount[n] + this.samples[o3];
            }
            if (splitScore > 0.0) {
                // Copy into a local that is never reassigned so the lambda can capture it.
                double value = splitValue;
                split = new OrdinalSplit(leaf, j, splitValue, splitScore, lo, hi, splitTrueCount, splitFalseCount, o -> xj.getDouble(o) <= value);
            }
        }
        return Optional.ofNullable(split);
    }

    /**
     * Trains a classification tree.
     *
     * <p>When {@code maxNodes} is {@code Integer.MAX_VALUE} the tree is grown by
     * passing a {@code null} queue to {@code split}; otherwise candidate splits
     * are kept in a priority queue and applied one at a time until the leaf
     * budget is reached or no candidate remains. A {@code maxNodes} of 0 or 1
     * leaves the root unsplit.
     *
     * @param x the predictor columns.
     * @param y the class index of each sample; values are used as indices into
     *          arrays of length {@code k}.
     * @param response the schema field describing the response variable.
     * @param k the number of classes.
     * @param rule the impurity rule used to score splits.
     * @param maxDepth the maximum tree depth, forwarded to {@link CART}.
     * @param maxNodes the maximum number of leaves; must not be negative.
     * @param nodeSize the minimum (weighted) number of samples on each side of a split.
     * @param mtry forwarded to {@link CART}; presumably the number of columns
     *             tried per split -- TODO confirm.
     * @param samples forwarded to {@link CART}; {@code fit} passes {@code null}.
     *                The inherited {@code samples} array is read here as a
     *                per-sample weight.
     * @param order forwarded to {@link CART}; {@code fit} passes {@code null}.
     * @throws IllegalArgumentException if {@code maxNodes} is negative.
     */
    public DecisionTree(DataFrame x, int[] y, StructField response, int k, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) {
        super(x, response, maxDepth, maxNodes, nodeSize, mtry, samples, order);
        // A negative budget used to fail inside the PriorityQueue constructor
        // with an empty message; keep the exception type but explain the cause.
        if (maxNodes < 0) {
            throw new IllegalArgumentException("Invalid maximum number of nodes: " + maxNodes);
        }
        this.k = k;
        this.y = y;
        this.rule = rule;

        // The root holds the weighted class histogram of the whole training set.
        int[] count = new int[k];
        int n = x.size();
        for (int i = 0; i < n; ++i) {
            count[y[i]] += this.samples[i];
        }
        DecisionNode node = new DecisionNode(count);
        this.root = node;

        Optional<Split> split = this.findBestSplit(node, 0, this.index.length, new boolean[x.ncol()]);
        if (maxNodes == Integer.MAX_VALUE) {
            split.ifPresent(s -> this.split(s, null));
        } else {
            // The capacity is only a sizing hint. Compute it in long arithmetic so
            // that 2 * maxNodes cannot overflow, cap it by the number of samples so
            // a large budget does not pre-allocate a huge array, and keep it >= 1
            // because PriorityQueue rejects smaller capacities (maxNodes == 0 is
            // what fit() produces when the data set is smaller than nodeSize).
            int capacity = (int) Math.max(1L, Math.min(2L * maxNodes, (long) this.index.length));
            PriorityQueue<Split> queue = new PriorityQueue<Split>(capacity, Split.comparator.reversed());
            split.ifPresent(queue::add);
            int leaves = 1;
            while (leaves < this.maxNodes && !queue.isEmpty()) {
                // Each successful split turns one leaf into two.
                if (this.split(queue.poll(), queue)) {
                    ++leaves;
                }
            }
        }
        this.root = this.root.merge();
        this.clear();
    }

    /**
     * Trains a decision tree with the default {@link Options}.
     *
     * @param formula the model formula selecting response and predictors.
     * @param data the training data.
     * @return the fitted tree.
     */
    public static DecisionTree fit(Formula formula, DataFrame data) {
        Options defaults = new Options();
        return fit(formula, data, defaults);
    }

    /**
     * Trains a decision tree with the given hyperparameters.
     *
     * <p>The formula is expanded against the data schema, the response is
     * encoded with {@link ClassLabels}, and every predictor column is made
     * available to each split ({@code mtry} equals the column count). When
     * {@code options.maxNodes} is not positive, the leaf budget defaults to
     * {@code data.size() / options.nodeSize}.
     *
     * @param formula the model formula selecting response and predictors.
     * @param data the training data.
     * @param options the hyperparameters.
     * @return the fitted tree, with its formula and class labels attached.
     */
    public static DecisionTree fit(Formula formula, DataFrame data, Options options) {
        Formula expanded = formula.expand(data.schema());
        DataFrame predictors = expanded.x(data);
        ValueVector response = expanded.y(data);
        ClassLabels labels = ClassLabels.fit(response);

        int leafBudget = options.maxNodes();
        if (leafBudget <= 0) {
            leafBudget = data.size() / options.nodeSize();
        }

        DecisionTree tree = new DecisionTree(
                predictors, labels.y, response.field(), labels.k, options.rule(),
                options.maxDepth(), leafBudget, options.nodeSize(), predictors.ncol(),
                null, null);
        tree.formula = expanded;
        tree.classes = labels.classes;
        return tree;
    }

    /**
     * Returns the number of classes.
     *
     * <p>{@code classes} may be {@code null} (both {@code predict} overloads
     * handle that state and return the raw class index), so this falls back to
     * {@code k} instead of throwing a {@code NullPointerException}.
     *
     * @return the size of the label set, or {@code k} when no label set is attached.
     */
    @Override
    public int numClasses() {
        return this.classes == null ? this.k : this.classes.size();
    }

    /**
     * Returns the class labels.
     *
     * <p>{@code classes} may be {@code null}; in that state {@code predict}
     * returns the raw class index, so the labels reported here are the indices
     * {@code 0 .. k-1} rather than a {@code NullPointerException}.
     *
     * @return the values of the attached label set, or a new array
     *         {@code {0, 1, ..., k-1}} when no label set is attached.
     */
    @Override
    public int[] classes() {
        if (this.classes != null) {
            return this.classes.values;
        }
        int[] labels = new int[this.k];
        for (int i = 0; i < labels.length; ++i) {
            labels[i] = i;
        }
        return labels;
    }

    /**
     * Predicts the class of an instance.
     *
     * @param x the instance to classify.
     * @return the label mapped through {@code classes}, or the raw class index
     *         of the reached leaf when {@code classes} is {@code null}.
     */
    @Override
    public int predict(Tuple x) {
        DecisionNode leaf = (DecisionNode) root.predict(predictors(x));
        int label = leaf.output();
        if (classes != null) {
            return classes.valueOf(label);
        }
        return label;
    }

    /**
     * Always returns {@code true}; this class implements
     * {@code predict(Tuple, double[])}, which fills in per-class posteriors.
     *
     * @return {@code true}.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class of an instance and reports the leaf's posteriors.
     *
     * @param x the instance to classify.
     * @param posteriori output array handed to {@code DecisionNode.posteriori}
     *                   of the reached leaf.
     * @return the label mapped through {@code classes}, or the raw class index
     *         of the reached leaf when {@code classes} is {@code null}.
     */
    @Override
    public int predict(Tuple x, double[] posteriori) {
        DecisionNode leaf = (DecisionNode) root.predict(predictors(x));
        leaf.posteriori(posteriori);
        int label = leaf.output();
        if (classes != null) {
            return classes.valueOf(label);
        }
        return label;
    }

    /**
     * Returns the formula stored on this tree.
     *
     * @return the {@code formula} field; {@code fit} sets it to the expanded formula.
     */
    @Override
    public Formula formula() {
        return this.formula;
    }

    /**
     * Returns the schema stored on this tree.
     *
     * @return the {@code schema} field; {@code findBestSplit} reads the measure
     *         of each predictor column from it.
     */
    @Override
    public StructType schema() {
        return this.schema;
    }

    /**
     * Wraps an existing node structure in a tree without training.
     * Used by {@code prune} to package the pruned root.
     *
     * @param formula the model formula.
     * @param schema the schema of the predictors.
     * @param response the schema field of the response variable.
     * @param root the root node of the tree.
     * @param k the number of classes.
     * @param rule the impurity rule.
     * @param importance the per-column importance values.
     * @param classes the class label set; may be {@code null}.
     */
    private DecisionTree(Formula formula, StructType schema, StructField response, Node root, int k, SplitRule rule, double[] importance, IntSet classes) {
        super(formula, schema, response, root, importance);
        this.classes = classes;
        this.rule = rule;
        this.k = k;
    }

    /**
     * Returns a pruned copy of this tree, using this tree's own formula and
     * class labels to read the test data.
     *
     * @param test the data used to measure classification errors.
     * @return a new, pruned tree; this tree is not replaced.
     */
    public DecisionTree prune(DataFrame test) {
        Formula ownFormula = this.formula;
        IntSet ownClasses = this.classes;
        return prune(test, ownFormula, ownClasses);
    }

    /**
     * Returns a pruned copy of this tree, reading the test data with an
     * explicitly supplied formula and label set.
     *
     * <p>The supplied {@code formula} and {@code classes} are only used while
     * evaluating the test data; the returned tree carries this tree's own
     * {@code formula} and {@code classes} fields.
     *
     * @param test the data used to measure classification errors.
     * @param formula the formula used to extract predictors and response from each test row.
     * @param classes the label set used to map test responses to class indices.
     * @return a new tree built from the pruned root and an adjusted copy of the importance array.
     */
    DecisionTree prune(DataFrame test, Formula formula, IntSet classes) {
        // Work on a copy so that this tree's importance values stay untouched.
        double[] adjustedImportance = this.importance.clone();
        Prune result = prune(this.root, test.toList(), adjustedImportance, formula, classes);
        return new DecisionTree(this.formula, this.schema, this.response, result.node(), this.k, this.rule, adjustedImportance, this.classes);
    }

    /**
     * Recursively prunes the subtree rooted at {@code node} against test data.
     *
     * <p>Both children of an internal node are pruned first. The node is then
     * collapsed into a single leaf when that leaf would make strictly fewer
     * errors on {@code test} than the two pruned children together; in that
     * case the node's score is subtracted from the importance of its feature.
     *
     * @param node the root of the subtree to prune.
     * @param test the test rows that reach {@code node}.
     * @param importance the importance array to adjust in place.
     * @param formula the formula used to extract predictors and response from a row.
     * @param labels the label set used to map responses to class indices.
     * @return the pruned subtree, its error count on {@code test}, and its class counts.
     */
    private Prune prune(Node node, List<? extends Tuple> test, double[] importance, Formula formula, IntSet labels) {
        if (node instanceof DecisionNode leaf) {
            int leafErrors = countErrors(leaf.output(), test, formula, labels);
            return new Prune(node, leafErrors, leaf.count());
        }

        InternalNode parent = (InternalNode) node;

        // Route every test row to the branch it would follow at this node.
        ArrayList<Tuple> trueBranch = new ArrayList<Tuple>();
        ArrayList<Tuple> falseBranch = new ArrayList<Tuple>();
        for (Tuple tuple : test) {
            ArrayList<Tuple> side = parent.branch(formula.x(tuple)) ? trueBranch : falseBranch;
            side.add(tuple);
        }

        Prune trueChild = prune(parent.trueChild(), trueBranch, importance, formula, labels);
        Prune falseChild = prune(parent.falseChild(), falseBranch, importance, formula, labels);

        // Class counts of this node are the sums of the counts of both children.
        int[] merged = new int[this.k];
        for (int c = 0; c < this.k; ++c) {
            merged[c] = trueChild.count()[c] + falseChild.count()[c];
        }

        int majority = MathEx.whichMax(merged);
        int collapsedErrors = countErrors(majority, test, formula, labels);
        int childErrors = trueChild.error() + falseChild.error();

        if (collapsedErrors < childErrors) {
            Node collapsed = new DecisionNode(merged);
            importance[parent.feature()] -= parent.score();
            return new Prune(collapsed, collapsedErrors, merged);
        }
        Node rebuilt = parent.replace(trueChild.node(), falseChild.node());
        return new Prune(rebuilt, childErrors, merged);
    }

    /** Counts the rows of {@code test} whose class index differs from {@code output}. */
    private static int countErrors(int output, List<? extends Tuple> test, Formula formula, IntSet labels) {
        int errors = 0;
        for (Tuple tuple : test) {
            if (output != labels.indexOf(formula.yint(tuple))) {
                ++errors;
            }
        }
        return errors;
    }

    /**
     * Hyperparameters of decision tree training.
     *
     * @param rule the impurity rule used to score splits.
     * @param maxDepth the maximum tree depth; must be at least 2.
     * @param maxNodes the maximum number of leaves; {@code fit} derives a
     *                 default from the data size when this is not positive.
     * @param nodeSize the minimum number of samples on each side of a split; must be at least 1.
     */
    public record Options(SplitRule rule, int maxDepth, int maxNodes, int nodeSize) {
        /**
         * Validates the depth and node size.
         *
         * @throws IllegalArgumentException if {@code maxDepth < 2} or {@code nodeSize < 1}.
         */
        public Options {
            if (!(maxDepth >= 2)) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }
            if (!(nodeSize >= 1)) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }
        }

        /** Creates the defaults: GINI rule, depth 20, automatic leaf budget (0), node size 5. */
        public Options() {
            this(SplitRule.GINI, 20, 0, 5);
        }

        /**
         * Exports the options under the {@code smile.cart.*} property keys.
         *
         * @return a new {@link Properties} holding the four settings as strings.
         */
        public Properties toProperties() {
            Properties exported = new Properties();
            exported.setProperty("smile.cart.split_rule", rule.toString());
            exported.setProperty("smile.cart.max_depth", String.valueOf(maxDepth));
            exported.setProperty("smile.cart.max_nodes", String.valueOf(maxNodes));
            exported.setProperty("smile.cart.node_size", String.valueOf(nodeSize));
            return exported;
        }

        /**
         * Reads the options from the {@code smile.cart.*} property keys, using
         * the same defaults as the no-argument constructor for missing keys.
         *
         * @param props the properties to read.
         * @return the parsed options.
         */
        public static Options of(Properties props) {
            String ruleName = props.getProperty("smile.cart.split_rule", "GINI");
            SplitRule rule = SplitRule.valueOf(ruleName);
            int maxDepth = intProperty(props, "smile.cart.max_depth", "20");
            int maxNodes = intProperty(props, "smile.cart.max_nodes", "0");
            int nodeSize = intProperty(props, "smile.cart.node_size", "5");
            return new Options(rule, maxDepth, maxNodes, nodeSize);
        }

        /** Parses an integer property, substituting {@code fallback} when the key is absent. */
        private static int intProperty(Properties props, String key, String fallback) {
            return Integer.parseInt(props.getProperty(key, fallback));
        }
    }

    /**
     * Result of pruning one subtree.
     *
     * @param node the root of the pruned subtree.
     * @param error the number of misclassified test rows in the subtree.
     * @param count the per-class counts of the subtree.
     */
    record Prune(Node node, int error, int[] count) {
    }
}

