package smile.classification;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Properties;
import java.util.stream.Collectors;
import smile.base.cart.CART;
import smile.base.cart.DecisionNode;
import smile.base.cart.InternalNode;
import smile.base.cart.LeafNode;
import smile.base.cart.Node;
import smile.base.cart.NominalSplit;
import smile.base.cart.OrdinalSplit;
import smile.base.cart.Split;
import smile.base.cart.SplitRule;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.util.IntSet;

/* loaded from: input_file:smile/classification/DecisionTree.class */
public class DecisionTree extends CART implements SoftClassifier<Tuple>, DataFrameClassifier {
    private static final long serialVersionUID = 2;
    // Split rule handed to DecisionNode.impurity(...) when scoring nodes and candidate splits.
    private SplitRule rule;
    // Number of classes; sizes every per-class count array built by this tree.
    private int k;
    // Optional translation from the tree's class index to the reported label
    // (predict() returns the raw index when this is null).
    private IntSet labels;
    // Class index of each training row, indexed by row id; transient, so not serialized.
    private transient int[] y;

    /**
     * Result of pruning one subtree: the (possibly replaced) subtree root, the number of
     * misclassified pruning samples under it, and its per-class counts.
     */
    public static class Prune {
        /** Root of the subtree after pruning. */
        Node node;
        /** Number of pruning samples this subtree misclassifies. */
        int error;
        /** Per-class counts associated with the subtree. */
        int[] count;

        /**
         * Bundles the outcome of pruning a subtree.
         *
         * @param subtree root of the pruned subtree
         * @param errors number of misclassified pruning samples
         * @param classCount per-class counts of the subtree
         */
        Prune(Node subtree, int errors, int[] classCount) {
            this.node = subtree;
            this.error = errors;
            this.count = classCount;
        }
    }

    /**
     * Returns the impurity of a leaf under this tree's split rule.
     *
     * @param leafNode the leaf to score; it is cast to {@link DecisionNode}
     * @return the value of {@code DecisionNode.impurity(rule)} for that leaf
     */
    @Override // smile.base.cart.CART
    protected double impurity(LeafNode leafNode) {
        DecisionNode leaf = (DecisionNode) leafNode;
        return leaf.impurity(this.rule);
    }

    /**
     * Builds a leaf for the given training rows by accumulating, per class, the sample
     * weight of every row that belongs to it.
     *
     * @param iArr row ids of the training samples that fall into the new leaf
     * @return a {@link DecisionNode} holding the weighted per-class counts
     */
    @Override // smile.base.cart.CART
    protected LeafNode newNode(int[] iArr) {
        int[] classWeights = new int[this.k];
        for (int pos = 0; pos < iArr.length; pos++) {
            int row = iArr[pos];
            classWeights[this.y[row]] += this.samples[row];
        }
        return new DecisionNode(classWeights);
    }

    /**
     * Searches a single feature for the split of {@code leafNode} that yields the largest
     * impurity reduction.
     *
     * <p>The gain of a candidate is {@code d} minus the impurities of the two branches, each
     * weighted by the fraction of the node's (weighted) samples it receives. Candidates whose
     * branches hold fewer than {@code nodeSize} weighted samples are ignored, and only a
     * strictly positive gain produces a split.
     *
     * <p>For a feature measured on a {@link NominalScale}, each category value is tried as a
     * "equals this value" test. Otherwise the rows are walked in the feature's sorted order and
     * a threshold halfway between two consecutive values is tried wherever both the class and
     * the feature value change.
     *
     * @param leafNode the leaf to split; it is cast to {@link DecisionNode}
     * @param i index of the feature (column) to examine
     * @param d impurity of the node before splitting
     * @param i2 inclusive start of the node's range in the row-index arrays
     * @param i3 exclusive end of the node's range in the row-index arrays
     * @return the best split on this feature, or an empty Optional when no candidate improves
     *         on the current impurity
     */
    @Override // smile.base.cart.CART
    protected Optional<Split> findBestSplit(LeafNode leafNode, int i, double d, int i2, int i3) {
        DecisionNode node = (DecisionNode) leafNode;
        BaseVector column = this.x.column(i);
        int total = node.size();
        int[] nodeCount = node.count();
        // Kept as a double so the branch fractions below use floating-point division;
        // int / int would truncate every fraction of a proper split to zero.
        double totalWeight = total;

        int[] falseCount = new int[this.k];
        Split split = null;
        double bestGain = 0.0d;
        int bestTrueSize = 0;
        int bestFalseSize = 0;

        if (this.schema.field(i).measure instanceof NominalScale) {
            NominalScale scale = (NominalScale) this.schema.field(i).measure;
            int bestValue = -1;

            // Weighted class histogram for every category value of this feature.
            int[][] trueCount = new int[scale.size()][this.k];
            for (int pos = i2; pos < i3; pos++) {
                int row = this.index[pos];
                trueCount[column.getInt(row)][this.y[row]] += this.samples[row];
            }

            for (int value : scale.values()) {
                int trueSize = (int) MathEx.sum(trueCount[value]);
                int falseSize = total - trueSize;
                // Skip values that would leave either branch below the minimum node size.
                if (trueSize < this.nodeSize || falseSize < this.nodeSize) {
                    continue;
                }

                for (int c = 0; c < this.k; c++) {
                    falseCount[c] = nodeCount[c] - trueCount[value][c];
                }

                double gain = d
                        - (trueSize / totalWeight) * DecisionNode.impurity(this.rule, trueSize, trueCount[value])
                        - (falseSize / totalWeight) * DecisionNode.impurity(this.rule, falseSize, falseCount);
                if (gain > bestGain) {
                    bestValue = value;
                    bestTrueSize = trueSize;
                    bestFalseSize = falseSize;
                    bestGain = gain;
                }
            }

            if (bestGain > 0.0d) {
                final int splitValue = bestValue;
                split = new NominalSplit(leafNode, i, bestValue, bestGain, i2, i3, bestTrueSize, bestFalseSize,
                        row -> column.getInt(row) == splitValue);
            }
        } else {
            double bestThreshold = 0.0d;
            // Running weighted class histogram of the rows already passed (the "true" branch).
            int[] trueCount = new int[this.k];
            int[] sorted = this.order[i];

            int first = sorted[i2];
            double prevX = column.getDouble(first);
            int prevY = this.y[first];

            for (int pos = i2; pos < i3; pos++) {
                int row = sorted[pos];
                int rowY = this.y[row];
                double rowX = column.getDouble(row);

                // A threshold is only considered where both the class and the value change;
                // otherwise the sizes stay 0 and the candidate fails the node-size check.
                int trueSize = 0;
                int falseSize = 0;
                if (rowY != prevY && rowX != prevX) {
                    trueSize = (int) MathEx.sum(trueCount);
                    falseSize = total - trueSize;
                }

                if (trueSize >= this.nodeSize && falseSize >= this.nodeSize) {
                    for (int c = 0; c < this.k; c++) {
                        falseCount[c] = nodeCount[c] - trueCount[c];
                    }

                    double gain = d
                            - (trueSize / totalWeight) * DecisionNode.impurity(this.rule, trueSize, trueCount)
                            - (falseSize / totalWeight) * DecisionNode.impurity(this.rule, falseSize, falseCount);
                    if (gain > bestGain) {
                        bestThreshold = (rowX + prevX) / 2.0d;
                        bestTrueSize = trueSize;
                        bestFalseSize = falseSize;
                        bestGain = gain;
                    }
                }

                prevX = rowX;
                prevY = rowY;
                trueCount[rowY] += this.samples[row];
            }

            if (bestGain > 0.0d) {
                final double threshold = bestThreshold;
                split = new OrdinalSplit(leafNode, i, bestThreshold, bestGain, i2, i3, bestTrueSize, bestFalseSize,
                        row -> column.getDouble(row) <= threshold);
            }
        }
        return Optional.ofNullable(split);
    }

    /**
     * Grows a decision tree on the given training data.
     *
     * <p>The root is a leaf holding the weighted class counts of all rows. When {@code i3}
     * equals {@code Integer.MAX_VALUE} the root's best split is applied directly; otherwise
     * candidate splits are taken from a priority queue until the node budget is reached or
     * no candidate remains. Finally the tree is merged and the training state is cleared.
     *
     * @param dataFrame the predictors
     * @param iArr class index of each training row
     * @param structField the response field, forwarded to {@link CART}
     * @param i number of classes
     * @param splitRule rule used to measure impurity
     * @param i2 forwarded to {@link CART}; {@code fit} supplies the "smile.cart.max.depth" value here
     * @param i3 maximum number of nodes; {@code Integer.MAX_VALUE} disables the queue-based growth
     * @param i4 forwarded to {@link CART}; {@code fit} supplies the "smile.cart.node.size" value here
     * @param i5 forwarded to {@link CART}; {@code fit} passes -1
     * @param iArr2 forwarded to {@link CART}; {@code fit} passes null
     * @param iArr3 forwarded to {@link CART}; {@code fit} passes null
     */
    public DecisionTree(DataFrame dataFrame, int[] iArr, StructField structField, int i, SplitRule splitRule, int i2, int i3, int i4, int i5, int[] iArr2, int[][] iArr3) {
        super(dataFrame, structField, i2, i3, i4, i5, iArr2, iArr3);
        this.k = i;
        this.y = iArr;
        this.rule = splitRule;

        // Weighted class histogram of the whole training set becomes the root leaf.
        int[] classWeights = new int[i];
        for (int row = 0, n = dataFrame.size(); row < n; row++) {
            classWeights[iArr[row]] += this.samples[row];
        }
        DecisionNode rootLeaf = new DecisionNode(classWeights);
        this.root = rootLeaf;

        Optional<Split> rootSplit = findBestSplit(rootLeaf, 0, this.index.length, new boolean[dataFrame.ncols()]);
        if (i3 == Integer.MAX_VALUE) {
            // No node budget: apply the root split without a candidate queue.
            rootSplit.ifPresent(s -> split(s, null));
        } else {
            // Node budget: repeatedly apply the head of the queue until the budget is spent.
            PriorityQueue<Split> candidates = new PriorityQueue<>(2 * i3, Split.comparator.reversed());
            rootSplit.ifPresent(candidates::add);
            for (int leaves = 1; leaves < this.maxNodes && !candidates.isEmpty(); ) {
                if (split(candidates.poll(), candidates)) {
                    leaves++;
                }
            }
        }

        this.root = this.root.merge();
        clear();
    }

    /**
     * Fits a decision tree with default hyper-parameters, i.e. with an empty
     * {@link Properties} object.
     *
     * @param formula the model formula
     * @param dataFrame the training data
     * @return the fitted tree
     */
    public static DecisionTree fit(Formula formula, DataFrame dataFrame) {
        Properties defaults = new Properties();
        return fit(formula, dataFrame, defaults);
    }

    /**
     * Fits a decision tree with hyper-parameters read from properties.
     *
     * <p>Recognized keys and their defaults:
     * <ul>
     *   <li>{@code smile.cart.split.rule} - {@code GINI}</li>
     *   <li>{@code smile.cart.max.depth} - {@code 20}</li>
     *   <li>{@code smile.cart.max.nodes} - the number of rows divided by 5</li>
     *   <li>{@code smile.cart.node.size} - {@code 5}</li>
     * </ul>
     *
     * @param formula the model formula
     * @param dataFrame the training data
     * @param properties the hyper-parameters
     * @return the fitted tree
     * @throws NumberFormatException if a numeric property is not a valid integer
     * @throws IllegalArgumentException if the split rule name is not a {@link SplitRule} constant
     */
    public static DecisionTree fit(Formula formula, DataFrame dataFrame, Properties properties) {
        SplitRule splitRule = SplitRule.valueOf(properties.getProperty("smile.cart.split.rule", "GINI"));
        int maxDepth = Integer.parseInt(properties.getProperty("smile.cart.max.depth", "20"));
        int maxNodes = Integer.parseInt(properties.getProperty("smile.cart.max.nodes", String.valueOf(dataFrame.size() / 5)));
        int nodeSize = Integer.parseInt(properties.getProperty("smile.cart.node.size", "5"));
        return fit(formula, dataFrame, splitRule, maxDepth, maxNodes, nodeSize);
    }

    /**
     * Fits a decision tree with explicit hyper-parameters.
     *
     * <p>The predictors and the response are extracted with the formula, the response is
     * encoded through {@code ClassLabels.fit}, and the resulting tree remembers both the
     * formula and the label mapping.
     *
     * @param formula the model formula
     * @param dataFrame the training data
     * @param splitRule rule used to measure impurity
     * @param i value passed as the tree's maximum depth
     * @param i2 value passed as the tree's maximum number of nodes
     * @param i3 value passed as the tree's minimum node size
     * @return the fitted tree
     */
    public static DecisionTree fit(Formula formula, DataFrame dataFrame, SplitRule splitRule, int i, int i2, int i3) {
        DataFrame predictors = formula.x(dataFrame);
        ClassLabels codec = ClassLabels.fit(formula.y(dataFrame));

        DecisionTree tree = new DecisionTree(predictors, codec.y, codec.field, codec.k, splitRule,
                i, i2, i3, -1, null, (int[][]) null);
        tree.formula = formula;
        tree.labels = codec.labels;
        return tree;
    }

    /**
     * Predicts the class label of a sample.
     *
     * @param tuple the sample; its predictors are extracted with {@code predictors(tuple)}
     * @return the label mapped from the reached leaf's output, or the raw output when the
     *         tree has no label mapping
     */
    @Override // smile.classification.Classifier
    public int predict(Tuple tuple) {
        DecisionNode leaf = (DecisionNode) this.root.predict(predictors(tuple));
        int classIndex = leaf.output();
        if (this.labels != null) {
            return this.labels.valueOf(classIndex);
        }
        return classIndex;
    }

    /**
     * Predicts the class label of a sample and reports the leaf's posteriori values.
     *
     * @param tuple the sample; its predictors are extracted with {@code predictors(tuple)}
     * @param dArr output array handed to the reached leaf's {@code posteriori} method
     * @return the label mapped from the reached leaf's output, or the raw output when the
     *         tree has no label mapping
     */
    @Override // smile.classification.SoftClassifier
    public int predict(Tuple tuple, double[] dArr) {
        DecisionNode leaf = (DecisionNode) this.root.predict(predictors(tuple));
        leaf.posteriori(dArr);
        int classIndex = leaf.output();
        if (this.labels != null) {
            return this.labels.valueOf(classIndex);
        }
        return classIndex;
    }

    /**
     * Returns the formula stored on this tree.
     *
     * @return the model formula
     */
    @Override // smile.classification.DataFrameClassifier
    public Formula formula() {
        Formula modelFormula = this.formula;
        return modelFormula;
    }

    /**
     * Returns the schema stored on this tree.
     *
     * @return the schema
     */
    @Override // smile.classification.DataFrameClassifier
    public StructType schema() {
        StructType modelSchema = this.schema;
        return modelSchema;
    }

    /**
     * Wraps an already-built tree structure; used by pruning to create the result tree
     * without retraining.
     *
     * @param formula the model formula
     * @param structType the schema
     * @param structField the response field
     * @param node the root of the tree
     * @param i number of classes
     * @param splitRule rule used to measure impurity
     * @param dArr the importance array passed to {@link CART}
     * @param intSet the class label mapping, may be null
     */
    private DecisionTree(Formula formula, StructType structType, StructField structField, Node node, int i, SplitRule splitRule, double[] dArr, IntSet intSet) {
        super(formula, structType, structField, node, dArr);
        this.rule = splitRule;
        this.k = i;
        this.labels = intSet;
    }

    /**
     * Returns a pruned copy of this tree, using this tree's own formula and label mapping
     * to interpret the given data.
     *
     * @param dataFrame the data used to decide which subtrees to collapse
     * @return a new, pruned tree
     */
    public DecisionTree prune(DataFrame dataFrame) {
        Formula ownFormula = this.formula;
        IntSet ownLabels = this.labels;
        return prune(dataFrame, ownFormula, ownLabels);
    }

    /**
     * Returns a pruned copy of this tree, interpreting the given data with an explicit
     * formula and label mapping.
     *
     * <p>Pruning works on a clone of the importance array, so this tree's importance is
     * left untouched. The new tree keeps this tree's own formula and label mapping.
     *
     * @param dataFrame the data used to decide which subtrees to collapse
     * @param formula the formula used to extract predictors and the response from each row
     * @param intSet the mapping from response values to class indices
     * @return a new, pruned tree
     */
    DecisionTree prune(DataFrame dataFrame, Formula formula, IntSet intSet) {
        double[] prunedImportance = this.importance.clone();
        // Materialize the rows once with their element type instead of a raw List.
        List<Tuple> rows = dataFrame.stream().collect(Collectors.toList());
        Prune pruned = prune(this.root, rows, prunedImportance, formula, intSet);
        return new DecisionTree(this.formula, this.schema, this.response, pruned.node, this.k, this.rule, prunedImportance, this.labels);
    }

    /**
     * Recursively prunes the subtree rooted at {@code node}.
     *
     * <p>Both children of an internal node are pruned first, each on the rows routed to it.
     * The node is then collapsed into a single leaf when that leaf, predicting the majority
     * class of the combined child counts, misclassifies strictly fewer rows than the two
     * pruned children together. Collapsing subtracts the node's score from the importance
     * of its feature.
     *
     * @param node root of the subtree to prune
     * @param list rows that reach this subtree
     * @param dArr importance array, updated in place when a node is collapsed
     * @param formula formula used to extract predictors and the response from each row
     * @param intSet mapping from response values to class indices
     * @return the pruned subtree with its error count and per-class counts
     */
    private Prune prune(Node node, List<Tuple> list, double[] dArr, Formula formula, IntSet intSet) {
        if (node instanceof DecisionNode) {
            DecisionNode leaf = (DecisionNode) node;
            int leafError = countMisclassified(leaf.output(), list, formula, intSet);
            return new Prune(node, leafError, leaf.count());
        }

        InternalNode parent = (InternalNode) node;

        // Route every row to the child it would reach.
        List<Tuple> trueRows = new ArrayList<>();
        List<Tuple> falseRows = new ArrayList<>();
        for (Tuple row : list) {
            if (parent.branch(formula.x(row))) {
                trueRows.add(row);
            } else {
                falseRows.add(row);
            }
        }

        Prune truePrune = prune(parent.trueChild(), trueRows, dArr, formula, intSet);
        Prune falsePrune = prune(parent.falseChild(), falseRows, dArr, formula, intSet);

        int[] merged = new int[this.k];
        for (int c = 0; c < this.k; c++) {
            merged[c] = truePrune.count[c] + falsePrune.count[c];
        }

        // Error if this node were replaced by a leaf predicting the majority class.
        int collapsedError = countMisclassified(MathEx.whichMax(merged), list, formula, intSet);
        int subtreeError = truePrune.error + falsePrune.error;

        if (collapsedError < subtreeError) {
            Node collapsed = new DecisionNode(merged);
            // The split disappears, so its contribution to the feature importance is removed.
            dArr[parent.feature()] -= parent.score();
            return new Prune(collapsed, collapsedError, merged);
        }
        return new Prune(parent.replace(truePrune.node, falsePrune.node), subtreeError, merged);
    }

    /** Counts the rows whose class index, per the formula and label mapping, differs from {@code predicted}. */
    private static int countMisclassified(int predicted, List<Tuple> rows, Formula formula, IntSet intSet) {
        int errors = 0;
        for (Tuple row : rows) {
            if (predicted != intSet.indexOf(formula.yint(row))) {
                errors++;
            }
        }
        return errors;
    }
}
