/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.regression.RegressionTree;
import smile.util.Strings;

public class GradientTreeBoost
implements Regression<Tuple>,
DataFrameRegression {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GradientTreeBoost.class);
    private Formula formula;
    private RegressionTree[] trees;
    private double b;
    private double[] importance;
    private double shrinkage = 0.005;

    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this.formula = formula;
        this.trees = trees;
        this.b = b;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        return GradientTreeBoost.fit(formula, data, new Properties());
    }

    public static GradientTreeBoost fit(Formula formula, DataFrame data, Properties prop) {
        int ntrees = Integer.valueOf(prop.getProperty("smile.gbt.trees", "500"));
        Loss loss = Loss.valueOf(prop.getProperty("smile.gbt.loss", "LeastAbsoluteDeviation"));
        int maxDepth = Integer.valueOf(prop.getProperty("smile.gbt.max.depth", "20"));
        int maxNodes = Integer.valueOf(prop.getProperty("smile.gbt.max.nodes", "6"));
        int nodeSize = Integer.valueOf(prop.getProperty("smile.gbt.node.size", "5"));
        double shrinkage = Double.valueOf(prop.getProperty("smile.gbt.shrinkage", "0.05"));
        double subsample = Double.valueOf(prop.getProperty("smile.gbt.sample.rate", "0.7"));
        return GradientTreeBoost.fit(formula, data, loss, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
    }

    public static GradientTreeBoost fit(Formula formula, DataFrame data, Loss loss, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (shrinkage <= 0.0 || shrinkage > 1.0) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
        }
        DataFrame x = formula.x(data);
        double[] y = formula.y(data).toDoubleArray();
        int n = x.nrows();
        int N = (int)Math.round((double)n * subsample);
        int[][] order = CART.order(x);
        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];
        StructField field = new StructField("residual", (DataType)DataTypes.DoubleType);
        double b = loss.intercept(y);
        double[] residual = loss.residual();
        RegressionTree[] trees = new RegressionTree[ntrees];
        for (int t = 0; t < ntrees; ++t) {
            int i;
            Arrays.fill(samples, 0);
            MathEx.permutate((int[])permutation);
            for (i = 0; i < N; ++i) {
                int n2 = permutation[i];
                samples[n2] = samples[n2] + 1;
            }
            logger.info("Training {} tree", (Object)Strings.ordinal((int)(t + 1)));
            trees[t] = new RegressionTree(x, loss, field, maxDepth, maxNodes, nodeSize, x.ncols(), samples, order);
            for (i = 0; i < n; ++i) {
                int n3 = i;
                residual[n3] = residual[n3] - shrinkage * trees[t].predict((Tuple)x.get(i));
            }
        }
        double[] importance = new double[x.ncols()];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; ++i) {
                int n4 = i;
                importance[n4] = importance[n4] + imp[i];
            }
        }
        return new GradientTreeBoost(formula, trees, b, shrinkage, importance);
    }

    @Override
    public Formula formula() {
        return this.formula;
    }

    @Override
    public StructType schema() {
        return this.trees[0].schema();
    }

    public double[] importance() {
        return this.importance;
    }

    public int size() {
        return this.trees.length;
    }

    public RegressionTree[] trees() {
        return this.trees;
    }

    public void trim(int ntrees) {
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        this.trees = Arrays.copyOf(this.trees, ntrees);
    }

    @Override
    public double predict(Tuple x) {
        Tuple xt = this.formula.x(x);
        double y = this.b;
        for (RegressionTree tree : this.trees) {
            y += this.shrinkage * tree.predict(xt);
        }
        return y;
    }

    public double[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.nrows();
        int ntrees = this.trees.length;
        double[][] prediction = new double[ntrees][n];
        for (int j = 0; j < n; ++j) {
            Tuple xj = (Tuple)x.get(j);
            double base = this.b;
            for (int i = 0; i < ntrees; ++i) {
                prediction[i][j] = base += this.shrinkage * this.trees[i].predict(xj);
            }
        }
        return prediction;
    }
}

