/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.RegressionTree;
import smile.util.IterativeAlgorithmController;
import smile.validation.RegressionMetrics;

/**
 * Gradient boosted regression trees.
 * <p>
 * The model is an additive expansion: the prediction for an input is the
 * intercept {@code b} plus {@code shrinkage} times the sum of the predictions
 * of all regression trees. During training, each tree is grown on a random
 * subsample of the training rows using the residuals exposed by the
 * {@link Loss}, and the residuals are then reduced by the shrunken tree output.
 */
public class GradientTreeBoost implements DataFrameRegression, TreeSHAP {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GradientTreeBoost.class);
    /** The model formula. */
    private final Formula formula;
    /** The regression trees, in the order they were trained. */
    private final RegressionTree[] trees;
    /** The intercept, i.e. the starting prediction before any tree is added. */
    private final double b;
    /** The variable importance summed over {@link #trees}. */
    private final double[] importance;
    /** The factor applied to every tree's prediction. */
    private final double shrinkage;

    /**
     * Constructor.
     *
     * @param formula the model formula.
     * @param trees the regression trees.
     * @param b the intercept.
     * @param shrinkage the factor applied to every tree's prediction.
     * @param importance the variable importance; stored as-is (not copied).
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this.formula = formula;
        this.trees = trees;
        this.b = b;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Fits a gradient tree boosting model with 500 trees and the default
     * {@link Options} (least absolute deviation loss).
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        return GradientTreeBoost.fit(formula, data, new Options(500));
    }

    /**
     * Fits a gradient tree boosting model.
     * <p>
     * If {@code options.test} is non-null, validation metrics are computed and
     * logged after each tree. If {@code options.controller} is non-null, a
     * {@link TrainingStatus} is submitted after each tree, and training stops
     * early when the controller reports an interruption. In that case the
     * returned model contains every tree trained so far.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, Options options) {
        long startTime = System.nanoTime();
        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        double[] y = formula.y(data).toDoubleArray();

        Loss loss = options.loss;
        int ntrees = options.ntrees;
        double shrinkage = options.shrinkage;

        int n = x.size();
        // Number of distinct training rows drawn for each tree.
        int N = (int) Math.round((double) n * options.subsample);
        // Feature orders are computed once and shared by every tree.
        int[][] order = CART.order(x);
        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        StructField field = new StructField("residual", (DataType) DataTypes.DoubleType);
        double b = loss.intercept(y);
        // NOTE(review): the in-place updates below assume residual() returns the
        // loss's own residual array that trees are fitted to — confirm against Loss.
        double[] residual = loss.residual();

        // Optional validation data; its running predictions start at the intercept.
        DataFrame testx = null;
        double[] testy = null;
        double[] prediction = null;
        if (options.test != null) {
            testx = formula.x(options.test);
            testy = formula.y(options.test).toDoubleArray();
            prediction = new double[testy.length];
            Arrays.fill(prediction, b);
        }

        RegressionTree[] trees = new RegressionTree[ntrees];
        for (int t = 0; t < ntrees; t++) {
            // Mark the first N rows of a fresh random permutation as this tree's sample.
            Arrays.fill(samples, 0);
            MathEx.permutate(permutation);
            for (int i = 0; i < N; i++) {
                samples[permutation[i]]++;
            }

            RegressionTree tree = new RegressionTree(x, loss, field, options.maxDepth, options.maxNodes, options.nodeSize, x.ncol(), samples, order);
            trees[t] = tree;

            // Remove the shrunken contribution of the new tree from every residual.
            for (int i = 0; i < n; i++) {
                residual[i] -= shrinkage * tree.predict(x.get(i));
            }

            double lossValue = loss.value();
            logger.info("Tree {}: loss = {}", t + 1, lossValue);

            double fitTime = (double) (System.nanoTime() - startTime) / 1000000.0;
            RegressionMetrics metrics = null;
            if (options.test != null) {
                long testStartTime = System.nanoTime();
                for (int i = 0; i < testy.length; i++) {
                    prediction[i] += shrinkage * tree.predict(testx.get(i));
                }
                double scoreTime = (double) (System.nanoTime() - testStartTime) / 1000000.0;
                metrics = RegressionMetrics.of(fitTime, scoreTime, testy, prediction);
                logger.info("Validation metrics = {} ", metrics);
            }

            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t + 1, lossValue, metrics));
                if (options.controller.isInterrupted()) {
                    // Trees 0..t are complete and were just reported as t + 1 trees.
                    // Keep all of them: truncating to t would discard the finished
                    // tree and, when t == 0, produce a model with no trees at all.
                    trees = Arrays.copyOf(trees, t + 1);
                    break;
                }
            }
        }

        return new GradientTreeBoost(formula, trees, b, shrinkage, sumImportance(trees, x.ncol()));
    }

    /**
     * Sums the per-tree variable importance element-wise.
     *
     * @param trees the regression trees.
     * @param p the length of the result (the number of features).
     * @return the element-wise sum of {@code tree.importance()} over {@code trees}.
     */
    private static double[] sumImportance(RegressionTree[] trees, int p) {
        double[] total = new double[p];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; i++) {
                total[i] += imp[i];
            }
        }
        return total;
    }

    @Override
    public Formula formula() {
        return this.formula;
    }

    /**
     * Returns the schema reported by the first tree.
     *
     * @return the schema of the first tree.
     */
    @Override
    public StructType schema() {
        return this.trees[0].schema();
    }

    /**
     * Returns the variable importance summed over the model's trees.
     * The returned array is the model's internal array (not a copy).
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees.
     */
    public int size() {
        return this.trees.length;
    }

    /**
     * Returns the regression trees. The returned array is the model's
     * internal array (not a copy).
     *
     * @return the regression trees.
     */
    public RegressionTree[] trees() {
        return this.trees;
    }

    /**
     * Returns a new model that keeps only the first {@code ntrees} trees.
     * The variable importance of the new model is recomputed from the kept
     * trees so that it describes the trimmed model.
     *
     * @param ntrees the new (smaller or equal) model size.
     * @return the trimmed model.
     * @throws IllegalArgumentException if {@code ntrees} is less than 1 or
     *         greater than the current number of trees.
     */
    public GradientTreeBoost trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        RegressionTree[] kept = Arrays.copyOf(this.trees, ntrees);
        // Models built through the public constructor may carry no importance.
        double[] keptImportance = this.importance == null ? null : sumImportance(kept, this.importance.length);
        return new GradientTreeBoost(this.formula, kept, this.b, this.shrinkage, keptImportance);
    }

    /**
     * Predicts the response as the intercept plus the shrunken sum of all
     * tree predictions.
     *
     * @param x the input tuple; predictors are extracted via the formula.
     * @return the predicted value.
     */
    @Override
    public double predict(Tuple x) {
        Tuple xt = this.formula.x(x);
        double y = this.b;
        for (RegressionTree tree : this.trees) {
            y += this.shrinkage * tree.predict(xt);
        }
        return y;
    }

    /**
     * Computes staged predictions: {@code prediction[i][j]} is the prediction
     * for row {@code j} of {@code data} using only the first {@code i + 1} trees.
     * This is useful for choosing the number of trees (see {@link #trim}).
     *
     * @param data the data frame to predict.
     * @return the staged predictions, of shape {@code [size()][data rows]}.
     */
    public double[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.size();
        int ntrees = this.trees.length;
        double[][] prediction = new double[ntrees][n];
        for (int j = 0; j < n; j++) {
            Tuple xj = x.get(j);
            double base = this.b;
            for (int i = 0; i < ntrees; i++) {
                base += this.shrinkage * this.trees[i].predict(xj);
                prediction[i][j] = base;
            }
        }
        return prediction;
    }

    /**
     * Gradient tree boosting hyperparameters.
     *
     * @param loss the loss function.
     * @param ntrees the number of trees (at least 1).
     * @param maxDepth the maximal depth of each tree (at least 2).
     * @param maxNodes the maximum number of nodes per tree (at least 2).
     * @param nodeSize the node size passed to each {@link RegressionTree}
     *                 (at least 1; presumably the minimum leaf size).
     * @param shrinkage the factor applied to each tree's prediction, in (0, 1].
     * @param subsample the fraction of training rows drawn for each tree, in (0, 1].
     * @param test optional validation data evaluated after each tree; may be null.
     * @param controller optional training controller receiving a status after
     *                   each tree and able to stop training early; may be null.
     */
    public record Options(Loss loss, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample, DataFrame test, IterativeAlgorithmController<TrainingStatus> controller) {
        /** Validates the hyperparameters. */
        public Options {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of nodes: " + maxNodes);
            }
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }
            if (shrinkage <= 0.0 || shrinkage > 1.0) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }
            if (subsample <= 0.0 || subsample > 1.0) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
            }
        }

        /**
         * Constructor with the least absolute deviation loss and default settings.
         *
         * @param ntrees the number of trees.
         */
        public Options(int ntrees) {
            this(Loss.lad(), ntrees);
        }

        /**
         * Constructor with default settings: maxDepth 20, maxNodes 6, nodeSize 5,
         * shrinkage 0.05, subsample 0.7, no validation data and no controller.
         *
         * @param loss the loss function.
         * @param ntrees the number of trees.
         */
        public Options(Loss loss, int ntrees) {
            this(loss, ntrees, 20, 6, 5, 0.05, 0.7, null, null);
        }

        /**
         * Returns the persistent set of hyperparameters. The validation data and
         * the controller are not included.
         *
         * @return the persistent set.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.gradient_boost.loss", this.loss.toString());
            props.setProperty("smile.gradient_boost.trees", Integer.toString(this.ntrees));
            props.setProperty("smile.gradient_boost.max_depth", Integer.toString(this.maxDepth));
            props.setProperty("smile.gradient_boost.max_nodes", Integer.toString(this.maxNodes));
            props.setProperty("smile.gradient_boost.node_size", Integer.toString(this.nodeSize));
            props.setProperty("smile.gradient_boost.shrinkage", Double.toString(this.shrinkage));
            props.setProperty("smile.gradient_boost.sampling_rate", Double.toString(this.subsample));
            return props;
        }

        /**
         * Returns the options from properties. Missing keys take the same
         * defaults as {@code new Options(500)}.
         *
         * @param props the hyperparameters.
         * @return the options.
         */
        public static Options of(Properties props) {
            Loss loss = Loss.valueOf(props.getProperty("smile.gradient_boost.loss", "LeastAbsoluteDeviation"));
            int ntrees = Integer.parseInt(props.getProperty("smile.gradient_boost.trees", "500"));
            int maxDepth = Integer.parseInt(props.getProperty("smile.gradient_boost.max_depth", "20"));
            int maxNodes = Integer.parseInt(props.getProperty("smile.gradient_boost.max_nodes", "6"));
            int nodeSize = Integer.parseInt(props.getProperty("smile.gradient_boost.node_size", "5"));
            double shrinkage = Double.parseDouble(props.getProperty("smile.gradient_boost.shrinkage", "0.05"));
            double subsample = Double.parseDouble(props.getProperty("smile.gradient_boost.sampling_rate", "0.7"));
            return new Options(loss, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample, null, null);
        }
    }

    /**
     * Training status submitted to the controller after each tree.
     *
     * @param tree the number of trees trained so far.
     * @param loss the value of {@code Loss.value()} after adding the tree.
     * @param metrics the validation metrics, or null if no validation data was given.
     */
    public record TrainingStatus(int tree, double loss, RegressionMetrics metrics) {
    }
}

