/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Gradient tree boosting classifier.
 *
 * <p>The model is stored in one of two shapes:
 * <ul>
 *   <li>a single sequence of regression trees plus an intercept {@code b}
 *       (built by the two-class constructors and by {@code fit} when the
 *       response has exactly two classes); the sign of
 *       {@code b + shrinkage * sum(tree(x))} selects class index 1 or 0;</li>
 *   <li>a forest {@code forest[class][iteration]} with one tree sequence per
 *       class; the class whose shrunken tree sum is largest wins.</li>
 * </ul>
 *
 * <p>Which shape is in use is decided everywhere by {@code trees != null}, so a
 * model built from a forest that happens to have two rows is handled by the
 * per-class code path instead of dereferencing the absent {@code trees} array.
 */
public class GradientTreeBoost extends AbstractClassifier<Tuple>
        implements DataFrameClassifier, SHAP<Tuple> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GradientTreeBoost.class);

    /** The model formula; used to extract predictors from tuples and data frames. */
    private final Formula formula;
    /** The number of classes: 2 for the single-sequence model, {@code forest.length} otherwise. */
    private final int k;
    /** Tree sequence of the single-sequence (two-class) model; {@code null} for forest models. */
    private RegressionTree[] trees;
    /** Per-class tree sequences, {@code forest[class][iteration]}; {@code null} for the single-sequence model. */
    private RegressionTree[][] forest;
    /** Per-feature importance, summed over all trees at training time. */
    private final double[] importance;
    /** Intercept added to the tree sum of the single-sequence model. */
    private double b = 0.0;
    /** Factor applied to every tree's output before it is accumulated. */
    private final double shrinkage;

    /**
     * Creates a two-class model with the default label set {@code IntSet.of(2)}.
     *
     * @param formula    the model formula.
     * @param trees      the boosted regression trees.
     * @param b          the intercept.
     * @param shrinkage  the factor applied to each tree's output.
     * @param importance the per-feature importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this(formula, trees, b, shrinkage, importance, IntSet.of(2));
    }

    /**
     * Creates a two-class model.
     *
     * @param formula    the model formula.
     * @param trees      the boosted regression trees.
     * @param b          the intercept.
     * @param shrinkage  the factor applied to each tree's output.
     * @param importance the per-feature importance.
     * @param labels     the class labels, passed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = 2;
        this.trees = trees;
        this.b = b;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Creates a model from per-class tree sequences with the default label set
     * {@code IntSet.of(forest.length)}.
     *
     * @param formula    the model formula.
     * @param forest     the trees, indexed as {@code forest[class][iteration]}.
     * @param shrinkage  the factor applied to each tree's output.
     * @param importance the per-feature importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] forest, double shrinkage, double[] importance) {
        this(formula, forest, shrinkage, importance, IntSet.of(forest.length));
    }

    /**
     * Creates a model from per-class tree sequences.
     *
     * @param formula    the model formula.
     * @param forest     the trees, indexed as {@code forest[class][iteration]}.
     * @param shrinkage  the factor applied to each tree's output.
     * @param importance the per-feature importance.
     * @param labels     the class labels, passed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] forest, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = forest.length;
        this.forest = forest;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Fits a model with all hyper-parameters at their defaults.
     *
     * @param formula the model formula.
     * @param data    the training data.
     * @return the fitted model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a model with hyper-parameters read from {@code params}.
     *
     * <p>Recognized keys and their defaults:
     * {@code smile.gradient_boost.trees} (500),
     * {@code smile.gradient_boost.max_depth} (20),
     * {@code smile.gradient_boost.max_nodes} (6),
     * {@code smile.gradient_boost.node_size} (5),
     * {@code smile.gradient_boost.shrinkage} (0.05),
     * {@code smile.gradient_boost.sampling_rate} (0.7).
     *
     * @param formula the model formula.
     * @param data    the training data.
     * @param params  the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.gradient_boost.trees", "500"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.gradient_boost.max_depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.gradient_boost.max_nodes", "6"));
        int nodeSize = Integer.parseInt(params.getProperty("smile.gradient_boost.node_size", "5"));
        double shrinkage = Double.parseDouble(params.getProperty("smile.gradient_boost.shrinkage", "0.05"));
        double subsample = Double.parseDouble(params.getProperty("smile.gradient_boost.sampling_rate", "0.7"));
        return fit(formula, data, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
    }

    /**
     * Fits a model.
     *
     * @param formula   the model formula.
     * @param data      the training data.
     * @param ntrees    the number of boosting iterations; must be at least 1.
     * @param maxDepth  passed through to every {@link RegressionTree}.
     * @param maxNodes  passed through to every {@link RegressionTree}.
     * @param nodeSize  passed through to every {@link RegressionTree}.
     * @param shrinkage the factor applied to each tree's output; must be in (0, 1].
     * @param subsample the per-class sampling fraction; must be in (0, 1].
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code ntrees}, {@code shrinkage} or
     *         {@code subsample} is out of range.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (shrinkage <= 0.0 || shrinkage > 1.0) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
        }

        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        BaseVector y = formula.y(data);
        // NOTE(review): presumably a per-column ordering shared by all trees — confirm against CART.order.
        int[][] order = CART.order(x);
        ClassLabels codec = ClassLabels.fit(y);

        return codec.k == 2
                ? train2(formula, x, codec, order, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample)
                : traink(formula, x, codec, order, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
    }

    @Override
    public Formula formula() {
        return formula;
    }

    /**
     * Returns the schema reported by the first tree of the model.
     */
    @Override
    public StructType schema() {
        RegressionTree first = trees != null ? trees[0] : forest[0][0];
        return first.schema();
    }

    /**
     * Returns the per-feature importance array held by this model (not a copy).
     *
     * @return the feature importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Trains the single-sequence model used for two classes.
     */
    private static GradientTreeBoost train2(Formula formula, DataFrame x, ClassLabels codec, int[][] order, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        int n = x.nrow();
        int[] y = codec.y;
        int[] nc = classCounts(y, codec.k, n);

        Loss loss = Loss.logistic(y);
        double b = loss.intercept(null);
        // The array handed out by the loss is updated in place below with the
        // accumulated boosted output.
        // NOTE(review): presumably the loss reads it back when the next tree is grown — confirm against Loss.
        double[] h = loss.residual();

        StructField field = new StructField("residual", DataTypes.DoubleType);
        RegressionTree[] trees = new RegressionTree[ntrees];
        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; ++t) {
            sampling(samples, permutation, nc, y, subsample);
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            RegressionTree tree = new RegressionTree(x, loss, field, maxDepth, maxNodes, nodeSize, x.ncol(), samples, order);
            trees[t] = tree;

            for (int i = 0; i < n; ++i) {
                h[i] += shrinkage * tree.predict(x.get(i));
            }
        }

        double[] importance = new double[x.ncol()];
        addImportance(importance, trees);
        return new GradientTreeBoost(formula, trees, b, shrinkage, importance, codec.classes);
    }

    /**
     * Trains the per-class forest used when the response does not have exactly two classes.
     */
    private static GradientTreeBoost traink(Formula formula, DataFrame x, ClassLabels codec, int[][] order, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        int n = x.nrow();
        int k = codec.k;
        int[] y = codec.y;
        int[] nc = classCounts(y, k, n);

        StructField field = new StructField("residual", DataTypes.DoubleType);
        RegressionTree[][] forest = new RegressionTree[k][ntrees];
        // p[i] holds the softmax of the current per-class outputs of sample i; it is
        // shared with every per-class loss object created below.
        double[][] p = new double[n][k];
        // h[j] is the array handed out by the j-th loss; it accumulates class j's boosted output.
        double[][] h = new double[k][];
        Loss[] loss = new Loss[k];
        for (int j = 0; j < k; ++j) {
            loss[j] = Loss.logistic(j, k, y, p);
            h[j] = loss[j].residual();
        }

        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; ++t) {
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            // Refresh the class probabilities from the current boosted outputs.
            for (int i = 0; i < n; ++i) {
                double[] pi = p[i];
                for (int j = 0; j < k; ++j) {
                    pi[j] = h[j][i];
                }
                MathEx.softmax(pi);
            }

            for (int j = 0; j < k; ++j) {
                // A fresh subsample is drawn for every class, not once per iteration.
                sampling(samples, permutation, nc, y, subsample);

                RegressionTree tree = new RegressionTree(x, loss[j], field, maxDepth, maxNodes, nodeSize, x.ncol(), samples, order);
                forest[j][t] = tree;

                double[] hj = h[j];
                for (int i = 0; i < n; ++i) {
                    hj[i] += shrinkage * tree.predict(x.get(i));
                }
            }
        }

        double[] importance = new double[x.ncol()];
        for (RegressionTree[] grove : forest) {
            addImportance(importance, grove);
        }
        return new GradientTreeBoost(formula, forest, shrinkage, importance, codec.classes);
    }

    /**
     * Counts how many of the first {@code n} entries of {@code y} fall in each of the {@code k} classes.
     */
    private static int[] classCounts(int[] y, int k, int n) {
        int[] nc = new int[k];
        for (int i = 0; i < n; ++i) {
            nc[y[i]]++;
        }
        return nc;
    }

    /**
     * Adds the importance vector of every tree in {@code trees} to {@code importance}.
     */
    private static void addImportance(double[] importance, RegressionTree[] trees) {
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; ++i) {
                importance[i] += imp[i];
            }
        }
    }

    /**
     * Draws a stratified subsample without replacement.
     *
     * <p>{@code permutation} is shuffled in place; then, for each class {@code j},
     * the first {@code round(nc[j] * subsample)} samples of that class in
     * permutation order are marked with 1 in {@code samples}. All other entries
     * of {@code samples} are 0.
     *
     * @param samples     output: 1 for selected samples, 0 otherwise.
     * @param permutation sample indices; shuffled in place.
     * @param nc          the number of samples in each class.
     * @param y           the class index of each sample.
     * @param subsample   the fraction of each class to select.
     */
    private static void sampling(int[] samples, int[] permutation, int[] nc, int[] y, double subsample) {
        int n = samples.length;
        int k = nc.length;
        Arrays.fill(samples, 0);
        MathEx.permutate(permutation);

        for (int j = 0; j < k; ++j) {
            int quota = (int) Math.round(nc[j] * subsample);
            int taken = 0;
            for (int i = 0; i < n && taken < quota; ++i) {
                int xi = permutation[i];
                if (y[xi] == j) {
                    samples[xi] = 1;
                    ++taken;
                }
            }
        }
    }

    /**
     * Returns the total number of trees in the model, i.e. the length of {@link #trees()}.
     *
     * @return the number of trees.
     */
    public int size() {
        return trees().length;
    }

    /**
     * Returns the trees of the model.
     *
     * <p>For the single-sequence model this is the internal array itself. For a
     * forest model it is a newly allocated array with the per-class sequences
     * concatenated in class order.
     *
     * @return the regression trees.
     */
    public RegressionTree[] trees() {
        if (trees != null) {
            return trees;
        }
        return Arrays.stream(forest).flatMap(Arrays::stream).toArray(RegressionTree[]::new);
    }

    /**
     * Keeps only the first {@code ntrees} boosting iterations (per class for a forest model).
     *
     * @param ntrees the new number of iterations.
     * @throws IllegalArgumentException if {@code ntrees} is less than 1 or
     *         larger than the current number of iterations.
     */
    public void trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        if (trees != null) {
            if (ntrees > trees.length) {
                throw new IllegalArgumentException("The new model size is larger than the current size.");
            }
            if (ntrees < trees.length) {
                trees = Arrays.copyOf(trees, ntrees);
            }
            return;
        }

        // The first class's sequence length is taken as the current model size.
        int current = forest[0].length;
        if (ntrees > current) {
            throw new IllegalArgumentException("The new model size is larger than the current one.");
        }
        if (ntrees < current) {
            for (int j = 0; j < forest.length; ++j) {
                forest[j] = Arrays.copyOf(forest[j], ntrees);
            }
        }
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x the instance; predictors are extracted through the model formula.
     * @return the predicted class label, mapped through the model's label set.
     */
    @Override
    public int predict(Tuple x) {
        Tuple xt = formula.x(x);

        if (trees != null) {
            double score = b;
            for (RegressionTree tree : trees) {
                score += shrinkage * tree.predict(xt);
            }
            return classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double best = Double.NEGATIVE_INFINITY;
        int winner = -1;
        for (int j = 0; j < k; ++j) {
            double score = 0.0;
            for (RegressionTree tree : forest[j]) {
                score += shrinkage * tree.predict(xt);
            }
            // Strict comparison: ties keep the lower class index.
            if (score > best) {
                best = score;
                winner = j;
            }
        }
        return classes.valueOf(winner);
    }

    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of {@code x} and fills in the class probabilities.
     *
     * <p>For the single-sequence model with score {@code s},
     * {@code posteriori[0] = 1 / (1 + exp(2 s))} and
     * {@code posteriori[1] = 1 - posteriori[0]}. For a forest model the
     * probabilities are the softmax of the per-class scores.
     *
     * @param x          the instance; predictors are extracted through the model formula.
     * @param posteriori output array of length {@code k}, overwritten with the class probabilities.
     * @return the predicted class label, mapped through the model's label set.
     * @throws IllegalArgumentException if {@code posteriori.length != k}.
     */
    @Override
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);

        if (trees != null) {
            double score = b;
            for (RegressionTree tree : trees) {
                score += shrinkage * tree.predict(xt);
            }
            posteriori[0] = 1.0 / (1.0 + Math.exp(2.0 * score));
            posteriori[1] = 1.0 - posteriori[0];
            return classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double best = Double.NEGATIVE_INFINITY;
        int winner = -1;
        for (int j = 0; j < k; ++j) {
            posteriori[j] = 0.0;
            for (RegressionTree tree : forest[j]) {
                posteriori[j] += shrinkage * tree.predict(xt);
            }
            if (posteriori[j] > best) {
                best = posteriori[j];
                winner = j;
            }
        }

        // Softmax, shifted by the largest score so that exp() cannot overflow.
        double z = 0.0;
        for (int j = 0; j < k; ++j) {
            posteriori[j] = Math.exp(posteriori[j] - best);
            z += posteriori[j];
        }
        for (int j = 0; j < k; ++j) {
            posteriori[j] /= z;
        }
        return classes.valueOf(winner);
    }

    /**
     * Computes staged predictions on a data set.
     *
     * <p>{@code result[i][j]} is the prediction for row {@code j} using only the
     * first {@code i + 1} boosting iterations. For the single-sequence model the
     * running score starts at the intercept {@code b}, exactly as in
     * {@link #predict(Tuple)}, so the last stage agrees with {@code predict}.
     *
     * <p>NOTE(review): the entries are raw class indices ({@code 0..k-1}); unlike
     * {@code predict} they are not mapped through {@code classes.valueOf} —
     * confirm this is what callers expect.
     *
     * @param data the data set; predictors are extracted through the model formula.
     * @return the staged predictions, indexed as {@code [iteration][row]}.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = formula.x(data);
        int n = x.nrow();
        int ntrees = trees != null ? trees.length : forest[0].length;
        int[][] prediction = new int[ntrees][n];

        if (trees != null) {
            for (int j = 0; j < n; ++j) {
                Tuple xj = x.get(j);
                // Start from the intercept; starting from zero would drop it from every stage.
                double score = b;
                for (int i = 0; i < ntrees; ++i) {
                    score += shrinkage * trees[i].predict(xj);
                    prediction[i][j] = score > 0.0 ? 1 : 0;
                }
            }
        } else {
            double[] p = new double[k];
            for (int j = 0; j < n; ++j) {
                Tuple xj = x.get(j);
                Arrays.fill(p, 0.0);
                for (int i = 0; i < ntrees; ++i) {
                    for (int l = 0; l < k; ++l) {
                        p[l] += shrinkage * forest[l][i].predict(xj);
                    }
                    prediction[i][j] = MathEx.whichMax(p);
                }
            }
        }
        return prediction;
    }

    /**
     * Computes SHAP values over a data frame.
     *
     * <p>The formula is first bound to the schema of {@code data}; the rows are
     * then handed, as a parallel stream, to the stream-based {@code shap} overload.
     *
     * @param data the data set.
     * @return the SHAP values as returned by the stream-based overload.
     */
    @Override
    public double[] shap(DataFrame data) {
        formula.bind(data.schema());
        return shap(data.stream().parallel());
    }

    /**
     * Computes the SHAP values of a single instance.
     *
     * <p>The result has {@code p * k} entries laid out as
     * {@code phi[feature * k + class]}, where {@code p} is the number of
     * predictors. For the single-sequence model each tree's value for a feature
     * is added to both class slots of that feature. Every entry is finally
     * divided by the number of boosting iterations.
     *
     * @param x the instance; predictors are extracted through the model formula.
     * @return the SHAP values.
     */
    @Override
    public double[] shap(Tuple x) {
        Tuple xt = formula.x(x);
        int p = xt.length();
        double[] phi = new double[p * k];
        int ntrees;

        if (trees != null) {
            ntrees = trees.length;
            for (RegressionTree tree : trees) {
                double[] phii = tree.shap(xt);
                for (int i = 0; i < p; ++i) {
                    phi[2 * i] += phii[i];
                    phi[2 * i + 1] += phii[i];
                }
            }
        } else {
            ntrees = forest[0].length;
            for (int c = 0; c < k; ++c) {
                for (RegressionTree tree : forest[c]) {
                    double[] phii = tree.shap(xt);
                    for (int j = 0; j < p; ++j) {
                        phi[j * k + c] += phii[j];
                    }
                }
            }
        }

        for (int i = 0; i < phi.length; ++i) {
            phi[i] /= ntrees;
        }
        return phi;
    }
}

