/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;
import smile.validation.ClassificationMetrics;

public class GradientTreeBoost
extends AbstractClassifier<Tuple>
implements DataFrameClassifier,
SHAP<Tuple> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GradientTreeBoost.class);
    /** The model formula; applied to incoming tuples/data frames before prediction. */
    private final Formula formula;
    /** The number of classes: fixed to 2 by the binary constructor, {@code trees.length} otherwise. */
    private final int k;
    /**
     * The boosted regression trees. The binary constructor stores its single
     * tree sequence as the only row; the multi-class constructor stores one
     * row of trees per class.
     */
    private final RegressionTree[][] trees;
    /** The per-variable importance supplied at construction (summed over all trees by the trainers). */
    private final double[] importance;
    /** The initial score of the binary model; the multi-class constructor sets it to 0.0. */
    private final double b;
    /** The factor each tree's output is multiplied by before it is added to the score. */
    private final double shrinkage;

    /**
     * Creates a binary model whose class labels are {@code IntSet.of(2)}.
     *
     * @param formula the model formula.
     * @param trees the boosted regression trees.
     * @param b the initial score (intercept).
     * @param shrinkage the factor applied to each tree's output.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this(formula, trees, b, shrinkage, importance, IntSet.of(2));
    }

    /**
     * Creates a binary model with explicit class labels.
     *
     * @param formula the model formula.
     * @param trees the boosted regression trees.
     * @param b the initial score (intercept).
     * @param shrinkage the factor applied to each tree's output.
     * @param importance the variable importance.
     * @param labels the class labels, passed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.k = 2;
        this.b = b;
        this.shrinkage = shrinkage;
        this.formula = formula;
        this.importance = importance;
        // The binary model is stored as a forest with a single row of trees.
        this.trees = new RegressionTree[][] {trees};
    }

    /**
     * Creates a multi-class model whose class labels are
     * {@code IntSet.of(trees.length)}.
     *
     * @param formula the model formula.
     * @param trees the regression trees, one row per class.
     * @param shrinkage the factor applied to each tree's output.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance) {
        this(formula, trees, shrinkage, importance, IntSet.of(trees.length));
    }

    /**
     * Creates a multi-class model with explicit class labels.
     *
     * @param formula the model formula.
     * @param trees the regression trees, one row per class.
     * @param shrinkage the factor applied to each tree's output.
     * @param importance the variable importance.
     * @param labels the class labels, passed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.trees = trees;
        this.k = trees.length;
        // The multi-class model has no intercept term.
        this.b = 0.0;
        this.shrinkage = shrinkage;
        this.formula = formula;
        this.importance = importance;
    }

    /**
     * Fits a model with {@code new Options(500)}, i.e. 500 trees and the
     * remaining hyperparameters at the defaults of that constructor.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        Options defaults = new Options(500);
        return fit(formula, data, defaults);
    }

    /**
     * Fits a gradient tree boosting classifier.
     *
     * <p>The formula is expanded against the schema of {@code data}, the
     * response is encoded with {@code ClassLabels.fit}, and training is
     * dispatched to the binary trainer when exactly two classes are found,
     * otherwise to the multi-class trainer.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param options the hyperparameters.
     * @return the fitted model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, Options options) {
        Formula expanded = formula.expand(data.schema());
        DataFrame predictors = expanded.x(data);
        ValueVector response = expanded.y(data);
        int[][] order = CART.order(predictors);
        ClassLabels codec = ClassLabels.fit(response);

        return codec.k == 2
                ? train2(expanded, predictors, codec, order, options)
                : traink(expanded, predictors, codec, order, options);
    }

    /**
     * Returns the model formula.
     *
     * @return the model formula.
     */
    @Override
    public Formula formula() {
        return formula;
    }

    /**
     * Returns the schema reported by the first tree of the first row.
     *
     * @return the schema of the first regression tree.
     */
    @Override
    public StructType schema() {
        RegressionTree first = trees[0][0];
        return first.schema();
    }

    /**
     * Returns the variable importance array held by this model. The internal
     * array is returned as is, without copying.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Trains the two-class model.
     *
     * <p>Each round draws a per-class subsample, fits one regression tree
     * against the logistic loss, and adds the shrunken tree output to the
     * per-sample score array obtained from {@code loss.residual()}. When
     * {@code options.test} is set, the validation logits are updated
     * incrementally and metrics are logged after every tree.
     *
     * <p>If the controller reports an interruption, training stops and the
     * returned model keeps every tree fitted so far, including the tree of
     * the round that was just reported to the controller.
     *
     * @param formula the expanded model formula.
     * @param x the predictor columns of the training data.
     * @param codec the encoded class labels.
     * @param order the ordering computed by {@code CART.order(x)}.
     * @param options the hyperparameters.
     * @return the fitted binary model.
     */
    private static GradientTreeBoost train2(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
        long startTime = System.nanoTime();
        int n = x.nrow();
        int p = x.ncol();
        int k = codec.k;
        int[] y = codec.y;

        // Class sizes, needed for the per-class subsampling quota.
        int[] nc = new int[k];
        for (int i = 0; i < n; i++) {
            nc[y[i]]++;
        }

        Loss loss = Loss.logistic(y);
        double b = loss.intercept(null);
        // NOTE(review): h is updated in place below; presumably the loss
        // observes these updates when computing loss.value() — confirm.
        double[] h = loss.residual();
        StructField field = new StructField("residual", DataTypes.DoubleType);

        DataFrame testx = null;
        int[] testy = null;
        int[] prediction = null;
        double[] logit = null;
        double[] probability = null;
        if (options.test != null) {
            testx = formula.x(options.test);
            testy = codec.indexOf(formula.y(options.test).toIntArray());
            prediction = new int[testy.length];
            logit = new double[testy.length];
            probability = new double[testy.length];
            // Validation scores start from the intercept, like predict().
            Arrays.fill(logit, b);
        }

        int ntrees = options.ntrees;
        double shrinkage = options.shrinkage;
        RegressionTree[] trees = new RegressionTree[ntrees];
        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            sampling(samples, permutation, nc, y, options.subsample);

            RegressionTree tree = new RegressionTree(x, loss, field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
            trees[t] = tree;

            for (int i = 0; i < n; i++) {
                h[i] += shrinkage * tree.predict(x.get(i));
            }

            double lossValue = loss.value();
            logger.info("Tree {}: loss = {}", t + 1, lossValue);

            // Elapsed training time in milliseconds since the start of train2.
            double fitTime = (System.nanoTime() - startTime) / 1000000.0;
            ClassificationMetrics metrics = null;
            if (options.test != null) {
                long testStartTime = System.nanoTime();
                for (int i = 0; i < testy.length; i++) {
                    logit[i] += shrinkage * tree.predict(testx.get(i));
                    prediction[i] = logit[i] > 0.0 ? 1 : 0;
                    probability[i] = 1.0 - 1.0 / (1.0 + Math.exp(2.0 * logit[i]));
                }
                double scoreTime = (System.nanoTime() - testStartTime) / 1000000.0;
                metrics = ClassificationMetrics.binary(fitTime, scoreTime, testy, prediction, probability);
                logger.info("Validation metrics = {} ", metrics);
            }

            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t + 1, lossValue, metrics));
                if (options.controller.isInterrupted()) {
                    // trees[0..t] are fitted at this point, so keep t + 1 of
                    // them. Truncating to t would discard the tree that was
                    // just reported and leave an empty model when t == 0.
                    trees = Arrays.copyOf(trees, t + 1);
                    break;
                }
            }
        }

        double[] importance = new double[p];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; i++) {
                importance[i] += imp[i];
            }
        }

        return new GradientTreeBoost(formula, trees, b, shrinkage, importance, codec.classes);
    }

    /**
     * Trains the multi-class model with one sequence of regression trees per
     * class.
     *
     * <p>Each round first refreshes the class probabilities as the softmax of
     * the current per-class scores, then, for every class, draws a per-class
     * subsample, fits one regression tree against that class's logistic loss
     * and adds the shrunken tree output to the class score. When
     * {@code options.test} is set, validation logits are updated
     * incrementally and metrics are logged after every round.
     *
     * <p>If the controller reports an interruption, training stops and the
     * returned model keeps every round fitted so far, including the round
     * that was just reported to the controller.
     *
     * @param formula the expanded model formula.
     * @param x the predictor columns of the training data.
     * @param codec the encoded class labels.
     * @param order the ordering computed by {@code CART.order(x)}.
     * @param options the hyperparameters.
     * @return the fitted multi-class model.
     */
    private static GradientTreeBoost traink(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
        long startTime = System.nanoTime();
        int n = x.size();
        int p = x.ncol();
        int k = codec.k;
        int[] y = codec.y;

        // Class sizes, needed for the per-class subsampling quota.
        int[] nc = new int[k];
        for (int i = 0; i < n; i++) {
            nc[y[i]]++;
        }

        DataFrame testx = null;
        int[] testy = null;
        int[] prediction = null;
        double[][] logit = null;
        double[][] probability = null;
        if (options.test != null) {
            testx = formula.x(options.test);
            testy = codec.indexOf(formula.y(options.test).toIntArray());
            prediction = new int[testy.length];
            logit = new double[testy.length][k];
            probability = new double[testy.length][k];
        }

        int ntrees = options.ntrees;
        double shrinkage = options.shrinkage;
        StructField field = new StructField("residual", DataTypes.DoubleType);
        RegressionTree[][] forest = new RegressionTree[k][ntrees];

        // prob is handed to every loss object and rewritten each round below.
        // NOTE(review): presumably the losses read it when fitting — confirm.
        double[][] prob = new double[n][k];
        double[][] h = new double[k][];
        Loss[] loss = new Loss[k];
        for (int j = 0; j < k; j++) {
            loss[j] = Loss.logistic(j, k, y, prob);
            h[j] = loss[j].residual();
        }

        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            // Refresh class probabilities from the current scores.
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    prob[i][j] = h[j][i];
                }
                MathEx.softmax(prob[i]);
            }

            for (int j = 0; j < k; j++) {
                sampling(samples, permutation, nc, y, options.subsample);

                RegressionTree tree = new RegressionTree(x, loss[j], field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
                forest[j][t] = tree;

                double[] hj = h[j];
                for (int i = 0; i < n; i++) {
                    hj[i] += shrinkage * tree.predict(x.get(i));
                }
            }

            double lossValue = loss[0].value();
            logger.info("Tree {}: loss = {}", t + 1, lossValue);

            // Elapsed training time in milliseconds since the start of traink.
            double fitTime = (System.nanoTime() - startTime) / 1000000.0;
            ClassificationMetrics metrics = null;
            if (options.test != null) {
                long testStartTime = System.nanoTime();
                for (int i = 0; i < testy.length; i++) {
                    Tuple xt = testx.get(i);
                    double[] scores = logit[i];
                    for (int j = 0; j < k; j++) {
                        scores[j] += shrinkage * forest[j][t].predict(xt);
                    }

                    prediction[i] = MathEx.whichMax(scores);

                    // Softmax with the maximum subtracted before exponentiation.
                    double max = scores[prediction[i]];
                    double[] posteriori = probability[i];
                    double Z = 0.0;
                    for (int j = 0; j < k; j++) {
                        posteriori[j] = Math.exp(scores[j] - max);
                        Z += posteriori[j];
                    }
                    for (int j = 0; j < k; j++) {
                        posteriori[j] /= Z;
                    }
                }
                double scoreTime = (System.nanoTime() - testStartTime) / 1000000.0;
                metrics = ClassificationMetrics.of(fitTime, scoreTime, testy, prediction, probability);
                logger.info("Validation metrics = {} ", metrics);
            }

            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t + 1, lossValue, metrics));
                if (options.controller.isInterrupted()) {
                    // Rounds 0..t are fitted for every class, so keep t + 1
                    // trees per class. Truncating to t would discard the
                    // round just reported and leave empty rows when t == 0.
                    for (int j = 0; j < k; j++) {
                        forest[j] = Arrays.copyOf(forest[j], t + 1);
                    }
                    break;
                }
            }
        }

        double[] importance = new double[p];
        for (RegressionTree[] grove : forest) {
            for (RegressionTree tree : grove) {
                double[] imp = tree.importance();
                for (int i = 0; i < imp.length; i++) {
                    importance[i] += imp[i];
                }
            }
        }

        return new GradientTreeBoost(formula, forest, shrinkage, importance, codec.classes);
    }

    /**
     * Draws a per-class subsample without replacement.
     *
     * <p>{@code samples} is reset to all zeros, {@code permutation} is
     * shuffled in place, and then, for every class, the first
     * {@code round(nc[class] * subsample)} rows of that class in permutation
     * order are marked with 1.
     *
     * @param samples the output indicator array, one entry per row.
     * @param permutation the row indices; shuffled in place.
     * @param nc the number of rows in each class.
     * @param y the class index of each row.
     * @param subsample the fraction of each class to draw.
     */
    private static void sampling(int[] samples, int[] permutation, int[] nc, int[] y, double subsample) {
        Arrays.fill(samples, 0);
        MathEx.permutate(permutation);

        for (int label = 0; label < nc.length; label++) {
            int quota = (int) Math.round(nc[label] * subsample);
            int taken = 0;
            for (int pos = 0; pos < samples.length && taken < quota; pos++) {
                int row = permutation[pos];
                if (y[row] == label) {
                    samples[row] = 1;
                    taken++;
                }
            }
        }
    }

    /**
     * Returns the number of boosting rounds in the model, i.e. the number of
     * trees per row of {@link #trees()}.
     *
     * <p>The outer dimension of the tree array is the number of rows (one for
     * the binary model, one per class otherwise), so it does not measure the
     * model size. The inner length is what {@code trim} treats as the current
     * model size.
     *
     * @return the number of trees in each row of the model.
     */
    public int size() {
        return trees[0].length;
    }

    /**
     * Returns the regression trees. The internal array is returned as is,
     * without copying.
     *
     * @return the regression trees, indexed by row and then by round.
     */
    public RegressionTree[][] trees() {
        return trees;
    }

    /**
     * Returns a new model that keeps only the first {@code ntrees} trees of
     * every row. The formula, shrinkage, class labels and the importance
     * array are passed on unchanged (importance is not recomputed).
     *
     * @param ntrees the new number of trees per row.
     * @return the trimmed model.
     * @throws IllegalArgumentException if {@code ntrees} is less than 1 or
     *         larger than the current number of trees per row.
     */
    public GradientTreeBoost trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        if (ntrees > trees[0].length) {
            throw new IllegalArgumentException("The new model size is larger than the current one.");
        }

        if (k == 2) {
            RegressionTree[] head = Arrays.copyOf(trees[0], ntrees);
            return new GradientTreeBoost(formula, head, b, shrinkage, importance, classes);
        }

        RegressionTree[][] forest = new RegressionTree[k][];
        Arrays.setAll(forest, row -> Arrays.copyOf(trees[row], ntrees));
        return new GradientTreeBoost(formula, forest, shrinkage, importance, classes);
    }

    /**
     * Predicts the class label of an instance.
     *
     * <p>For the binary model the score is the intercept plus the shrunken
     * sum of all tree outputs, and class index 1 is chosen when the score is
     * positive. Otherwise the class with the strictly largest shrunken sum
     * wins. The chosen index is mapped through {@code classes.valueOf}.
     *
     * @param x the instance.
     * @return the predicted class label.
     */
    @Override
    public int predict(Tuple x) {
        Tuple xt = formula.x(x);

        if (k == 2) {
            double score = b;
            for (RegressionTree tree : trees[0]) {
                score += shrinkage * tree.predict(xt);
            }
            return classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double best = Double.NEGATIVE_INFINITY;
        int winner = -1;
        for (int j = 0; j < k; j++) {
            double score = 0.0;
            for (RegressionTree tree : trees[j]) {
                score += shrinkage * tree.predict(xt);
            }

            if (score > best) {
                best = score;
                winner = j;
            }
        }

        return classes.valueOf(winner);
    }

    /**
     * Always returns {@code true}; this class implements
     * {@code predict(Tuple, double[])}, which fills in class probabilities.
     *
     * @return {@code true}.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of an instance and fills in the class
     * probabilities.
     *
     * <p>For the binary model, {@code posteriori[0] = 1 / (1 + exp(2 * score))}
     * and {@code posteriori[1]} is its complement. Otherwise the probabilities
     * are the softmax of the per-class scores, computed with the maximum
     * score subtracted before exponentiation.
     *
     * @param x the instance.
     * @param posteriori the output array of length {@code k} receiving the
     *        class probabilities.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code posteriori.length != k}.
     */
    @Override
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);

        if (k == 2) {
            double score = b;
            for (RegressionTree tree : trees[0]) {
                score += shrinkage * tree.predict(xt);
            }

            double p0 = 1.0 / (1.0 + Math.exp(2.0 * score));
            posteriori[0] = p0;
            posteriori[1] = 1.0 - p0;
            return classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double best = Double.NEGATIVE_INFINITY;
        int winner = -1;
        for (int j = 0; j < k; j++) {
            double score = 0.0;
            for (RegressionTree tree : trees[j]) {
                score += shrinkage * tree.predict(xt);
            }
            posteriori[j] = score;

            if (score > best) {
                best = score;
                winner = j;
            }
        }

        double z = 0.0;
        for (int j = 0; j < k; j++) {
            double e = Math.exp(posteriori[j] - best);
            posteriori[j] = e;
            z += e;
        }

        for (int j = 0; j < k; j++) {
            posteriori[j] /= z;
        }

        return classes.valueOf(winner);
    }

    /**
     * Computes staged predictions: row {@code i} of the result holds the
     * predictions made with only the first {@code i + 1} trees of every row.
     *
     * <p>The binary branch starts its running score from the intercept
     * {@code b}, so the last row agrees with the decision made by
     * {@code predict(Tuple)}.
     *
     * <p>NOTE(review): entries are class indices (0/1, or the arg-max index),
     * not values mapped through {@code classes.valueOf} as in
     * {@code predict}. Confirm whether callers expect indices or labels.
     *
     * @param data the test data.
     * @return an {@code [ntrees][n]} array of predicted class indices.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = formula.x(data);
        int n = x.size();
        int ntrees = trees[0].length;
        int[][] prediction = new int[ntrees][n];

        if (k == 2) {
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                // Start from the intercept, exactly as predict() does.
                double score = b;
                for (int i = 0; i < ntrees; i++) {
                    score += shrinkage * trees[0][i].predict(xj);
                    prediction[i][j] = score > 0.0 ? 1 : 0;
                }
            }
        } else {
            double[] p = new double[k];
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                Arrays.fill(p, 0.0);
                for (int i = 0; i < ntrees; i++) {
                    for (int l = 0; l < k; l++) {
                        p[l] += shrinkage * trees[l][i].predict(xj);
                    }
                    prediction[i][j] = MathEx.whichMax(p);
                }
            }
        }

        return prediction;
    }

    /**
     * Computes SHAP values over a data frame by binding the formula to the
     * frame's schema and delegating to the stream overload with a parallel
     * stream of its rows.
     *
     * @param data the data.
     * @return the SHAP values returned by the stream overload.
     */
    @Override
    public double[] shap(DataFrame data) {
        formula.bind(data.schema());
        // Raw type kept on purpose: the stream overload is selected through
        // an unchecked conversion, exactly as with an inline raw cast.
        Stream rows = (Stream) data.stream().parallel();
        return shap(rows);
    }

    /**
     * Computes the SHAP values of one instance.
     *
     * <p>The result has {@code p * k} entries, where {@code p} is the length
     * of the formula-transformed tuple; the entry for feature {@code j} and
     * class {@code c} is at index {@code j * k + c}. For the binary model the
     * same per-tree value is added to both class slots of a feature. All
     * sums are finally divided by the number of trees per row.
     *
     * @param x the instance.
     * @return the SHAP values.
     */
    @Override
    public double[] shap(Tuple x) {
        Tuple xt = formula.x(x);
        int p = xt.length();
        double[] phi = new double[p * k];
        int ntrees = trees[0].length;

        if (k == 2) {
            for (RegressionTree tree : trees[0]) {
                double[] contribution = tree.shap(xt);
                for (int j = 0; j < p; j++) {
                    double value = contribution[j];
                    phi[2 * j] += value;
                    phi[2 * j + 1] += value;
                }
            }
        } else {
            for (int c = 0; c < k; c++) {
                for (RegressionTree tree : trees[c]) {
                    double[] contribution = tree.shap(xt);
                    for (int j = 0; j < p; j++) {
                        phi[j * k + c] += contribution[j];
                    }
                }
            }
        }

        for (int idx = 0; idx < phi.length; idx++) {
            phi[idx] /= (double) ntrees;
        }

        return phi;
    }

    /**
     * Gradient tree boosting hyperparameters.
     *
     * @param ntrees the number of trees (boosting rounds); at least 1.
     * @param maxDepth the maximal tree depth; at least 2.
     * @param maxNodes the maximum number of nodes per tree; at least 2.
     * @param nodeSize the node size passed to each regression tree; at least 1.
     * @param shrinkage the factor applied to each tree's output; rejected if
     *        {@code <= 0} or {@code > 1}.
     * @param subsample the sampling fraction per class; rejected if
     *        {@code <= 0} or {@code > 1}.
     * @param test optional validation data, evaluated after every round; may be null.
     * @param controller optional controller that receives a
     *        {@code TrainingStatus} after every round and can interrupt
     *        training; may be null.
     */
    public record Options(int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample, DataFrame test, IterativeAlgorithmController<TrainingStatus> controller) {
        /**
         * Validates the hyperparameters.
         *
         * @throws IllegalArgumentException if any value is out of range.
         */
        public Options {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }

            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }

            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of nodes: " + maxNodes);
            }

            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }

            // The range checks are kept in this exact form: rewriting them
            // with negation would change how NaN is treated.
            if (shrinkage <= 0.0 || shrinkage > 1.0) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }

            if (subsample <= 0.0 || subsample > 1.0) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
            }
        }

        /**
         * Creates options with the given number of trees and these defaults:
         * maxDepth 20, maxNodes 6, nodeSize 5, shrinkage 0.05, subsample 0.7,
         * no validation data and no controller.
         *
         * @param ntrees the number of trees.
         */
        public Options(int ntrees) {
            this(ntrees, 20, 6, 5, 0.05, 0.7, null, null);
        }

        /**
         * Writes the numeric hyperparameters under the
         * {@code smile.gradient_boost.*} keys. The validation data and the
         * controller are not written.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.gradient_boost.trees", String.valueOf(ntrees));
            props.setProperty("smile.gradient_boost.max_depth", String.valueOf(maxDepth));
            props.setProperty("smile.gradient_boost.max_nodes", String.valueOf(maxNodes));
            props.setProperty("smile.gradient_boost.node_size", String.valueOf(nodeSize));
            props.setProperty("smile.gradient_boost.shrinkage", String.valueOf(shrinkage));
            props.setProperty("smile.gradient_boost.sampling_rate", String.valueOf(subsample));
            return props;
        }

        /**
         * Reads the hyperparameters from the {@code smile.gradient_boost.*}
         * keys, falling back to 500 trees, maxDepth 20, maxNodes 6,
         * nodeSize 5, shrinkage 0.05 and sampling rate 0.7. The validation
         * data and the controller are left null.
         *
         * @param props the properties.
         * @return the options.
         */
        public static Options of(Properties props) {
            return new Options(
                    Integer.parseInt(props.getProperty("smile.gradient_boost.trees", "500")),
                    Integer.parseInt(props.getProperty("smile.gradient_boost.max_depth", "20")),
                    Integer.parseInt(props.getProperty("smile.gradient_boost.max_nodes", "6")),
                    Integer.parseInt(props.getProperty("smile.gradient_boost.node_size", "5")),
                    Double.parseDouble(props.getProperty("smile.gradient_boost.shrinkage", "0.05")),
                    Double.parseDouble(props.getProperty("smile.gradient_boost.sampling_rate", "0.7")),
                    null,
                    null);
        }
    }

    /**
     * A progress report submitted to the controller after each boosting round.
     *
     * @param tree the 1-based number of the round just completed.
     * @param loss the loss value after that round.
     * @param metrics the validation metrics, or null when no test data was given.
     */
    public record TrainingStatus(int tree, double loss, ClassificationMetrics metrics) { }
}

