/*
 * Decompiled with CFR 0.152.
 */
package smile.base.cart;

import java.io.Serializable;
import java.math.BigInteger;
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.InternalNode;
import smile.base.cart.LeafNode;
import smile.base.cart.Node;
import smile.base.cart.Split;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.sort.QuickSort;

public abstract class CART
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(CART.class);
    protected StructType schema;
    protected StructField response;
    protected Formula formula = null;
    protected Node root;
    protected int maxDepth = 20;
    protected int maxNodes = 6;
    protected int nodeSize = 5;
    protected int mtry = -1;
    protected double[] importance;
    protected transient DataFrame x;
    protected transient int[] samples;
    protected transient int[] index;
    protected transient int[][] order;
    private transient int[] buffer;

    private CART() {
    }

    public CART(Formula formula, StructType schema, StructField response, Node root, double[] importance) {
        this.formula = formula;
        this.schema = schema;
        this.response = response;
        this.root = root;
        this.importance = importance;
    }

    public CART(DataFrame x, StructField y, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order) {
        IntStream idx;
        this.x = x;
        this.response = y;
        this.schema = x.schema();
        this.importance = new double[x.ncols()];
        this.maxDepth = maxDepth;
        this.maxNodes = maxNodes;
        this.nodeSize = nodeSize;
        this.mtry = mtry;
        int n = x.size();
        int p = x.ncols();
        if (mtry < 1 || mtry > p) {
            logger.debug("Invalid mtry. Use all features.");
            this.mtry = this.schema.length();
        }
        if (maxDepth < 1) {
            throw new IllegalArgumentException("Invalid maximum depth: " + maxDepth);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }
        if (nodeSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }
        if (samples == null) {
            this.samples = Collections.nCopies(n, 1).parallelStream().mapToInt(i -> i).toArray();
            idx = IntStream.range(0, n);
        } else {
            this.samples = samples;
            idx = IntStream.range(0, samples.length).filter(i -> samples[i] > 0);
        }
        this.index = idx.toArray();
        this.buffer = new int[this.index.length];
        if (order == null) {
            this.order = CART.order(x);
        } else {
            this.order = new int[order.length][];
            for (int i2 = 0; i2 < order.length; ++i2) {
                if (order[i2] == null) continue;
                this.order[i2] = Arrays.stream(order[i2]).filter(o -> this.samples[o] > 0).toArray();
            }
        }
    }

    public int size() {
        return this.size(this.root);
    }

    private int size(Node node) {
        if (node instanceof LeafNode) {
            return 1;
        }
        InternalNode parent = (InternalNode)node;
        return this.size(parent.trueChild) + this.size(parent.falseChild) + 1;
    }

    public static int[][] order(DataFrame x) {
        int n = x.size();
        int p = x.ncols();
        StructType schema = x.schema();
        double[] a = new double[n];
        int[][] order = new int[p][];
        for (int j = 0; j < p; ++j) {
            Measure measure = schema.field((int)j).measure;
            if (measure != null && measure instanceof NominalScale) continue;
            x.column(j).toDoubleArray(a);
            order[j] = QuickSort.sort((double[])a);
        }
        return order;
    }

    protected Tuple predictors(Tuple x) {
        return this.formula == null ? x : this.formula.x(x);
    }

    protected void clear() {
        this.x = null;
        this.order = null;
        this.index = null;
        this.samples = null;
        this.buffer = null;
    }

    protected boolean split(Split split, PriorityQueue<Split> queue) {
        if (split.feature < 0) {
            throw new IllegalStateException("Split a node with invalid feature.");
        }
        if (split.depth >= this.maxDepth) {
            logger.debug("Reach maximum depth");
            return false;
        }
        if (split.trueCount < this.nodeSize || split.falseCount < this.nodeSize) {
            logger.debug("Node size is too small after splitting");
            return false;
        }
        int[] trueSamples = IntStream.range(split.lo, split.hi).map(i -> this.index[i]).filter(i -> split.predicate().test(i)).toArray();
        boolean[] trues = new boolean[this.samples.length];
        for (int i2 : trueSamples) {
            trues[i2] = true;
        }
        int[] falseSamples = IntStream.range(split.lo, split.hi).map(i -> this.index[i]).filter(i -> !trues[i]).toArray();
        int mid = split.lo + trueSamples.length;
        LeafNode trueChild = this.newNode(trueSamples);
        LeafNode falseChild = this.newNode(falseSamples);
        InternalNode node = split.toNode(trueChild, falseChild);
        this.shuffle(split.lo, mid, split.hi, trues);
        Optional<Split> trueSplit = this.findBestSplit(trueChild, split.lo, mid, (boolean[])split.unsplittable.clone());
        Optional<Split> falseSplit = this.findBestSplit(falseChild, mid, split.hi, split.unsplittable);
        if (trueChild.equals(falseChild) && !trueSplit.isPresent() && !falseSplit.isPresent()) {
            return false;
        }
        if (split.parent == null) {
            this.root = node;
        } else if (split.parent.trueChild == split.leaf) {
            split.parent.trueChild = node;
        } else if (split.parent.falseChild == split.leaf) {
            split.parent.falseChild = node;
        } else {
            throw new IllegalStateException("split.parent and leaf don't match");
        }
        int n = node.feature;
        this.importance[n] = this.importance[n] + node.score;
        trueSplit.ifPresent(s -> {
            s.parent = node;
            s.depth = split.depth + 1;
        });
        falseSplit.ifPresent(s -> {
            s.parent = node;
            s.depth = split.depth + 1;
        });
        if (queue == null) {
            trueSplit.ifPresent(s -> this.split((Split)s, null));
            falseSplit.ifPresent(s -> this.split((Split)s, null));
        } else {
            trueSplit.ifPresent(s -> queue.add((Split)s));
            falseSplit.ifPresent(s -> queue.add((Split)s));
        }
        return true;
    }

    protected Optional<Split> findBestSplit(LeafNode node, int lo, int hi, boolean[] unsplittable) {
        if (node.size() < 2 * this.nodeSize) {
            return Optional.empty();
        }
        double impurity = this.impurity(node);
        if (impurity == 0.0) {
            return Optional.empty();
        }
        int p = this.schema.length();
        int[] columns = IntStream.range(0, p).filter(i -> !unsplittable[i]).toArray();
        if (this.mtry < p) {
            MathEx.permutate((int[])columns);
        }
        IntStream stream = Arrays.stream(columns).limit(this.mtry);
        Optional<Split> split = (this.mtry < p ? stream : stream.parallel()).mapToObj(j -> {
            Optional<Split> s = this.findBestSplit(node, j, impurity, lo, hi);
            if (!s.isPresent()) {
                unsplittable[j] = true;
            }
            return s;
        }).filter(Optional::isPresent).map(Optional::get).max(Split.comparator);
        split.ifPresent(s -> {
            s.unsplittable = unsplittable;
        });
        return split;
    }

    protected abstract double impurity(LeafNode var1);

    protected abstract LeafNode newNode(int[] var1);

    protected abstract Optional<Split> findBestSplit(LeafNode var1, int var2, double var3, int var5, int var6);

    public double[] importance() {
        return this.importance;
    }

    public Node root() {
        return this.root;
    }

    public String dot() {
        StringBuilder builder = new StringBuilder();
        builder.append("digraph CART {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");
        String trueLabel = " [labeldistance=2.5, labelangle=45, headlabel=\"True\"];\n";
        String falseLabel = " [labeldistance=2.5, labelangle=-45, headlabel=\"False\"];\n";
        LinkedList<AbstractMap.SimpleEntry<Integer, Node>> queue = new LinkedList<AbstractMap.SimpleEntry<Integer, Node>>();
        queue.add(new AbstractMap.SimpleEntry<Integer, Node>(1, this.root));
        while (!queue.isEmpty()) {
            AbstractMap.SimpleEntry entry = (AbstractMap.SimpleEntry)queue.poll();
            int id = (Integer)entry.getKey();
            Node node = (Node)entry.getValue();
            builder.append(node.dot(this.schema, this.response, id));
            if (!(node instanceof InternalNode)) continue;
            int tid = 2 * id;
            int fid = 2 * id + 1;
            InternalNode inode = (InternalNode)node;
            queue.add(new AbstractMap.SimpleEntry<Integer, Node>(tid, inode.trueChild));
            queue.add(new AbstractMap.SimpleEntry<Integer, Node>(fid, inode.falseChild));
            builder.append(' ').append(id).append(" -> ").append(tid).append(trueLabel);
            builder.append(' ').append(id).append(" -> ").append(fid).append(falseLabel);
            if (id != 1) continue;
            trueLabel = "\n";
            falseLabel = "\n";
        }
        builder.append("}");
        return builder.toString();
    }

    private void shuffle(int low, int split, int high, boolean[] predicate) {
        Arrays.stream(this.order).filter(Objects::nonNull).forEach(o -> this.shuffle((int[])o, low, split, high, predicate));
        this.shuffle(this.index, low, split, high, predicate);
    }

    private void shuffle(int[] a, int low, int split, int high, boolean[] predicate) {
        int k = 0;
        int j = low;
        for (int i = low; i < high; ++i) {
            if (predicate[a[i]]) {
                a[j++] = a[i];
                continue;
            }
            this.buffer[k++] = a[i];
        }
        assert (split + k == high);
        System.arraycopy(this.buffer, 0, a, split, k);
    }

    public String toString() {
        ArrayList<String> lines = new ArrayList<String>();
        this.root.toString(this.schema, this.response, null, 0, BigInteger.ONE, lines);
        lines.add("* denotes terminal node");
        lines.add("node), split, n, loss, yval, (yprob)");
        lines.add("n=" + this.root.size());
        Collections.reverse(lines);
        return String.join((CharSequence)"\n", lines);
    }
}

