/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayer;
import smile.math.MathEx;
import smile.math.Scaler;
import smile.regression.Regression;
import smile.tensor.Vector;
import smile.util.Strings;

/**
 * Multilayer perceptron (MLP) regression model.
 *
 * <p>The network is assembled from a sequence of {@link LayerBuilder}s. When the
 * last builder does not produce an {@link OutputLayer}, a single-neuron output
 * layer with a linear output function and mean squared error cost is appended.
 *
 * <p>An optional {@link Scaler} may be attached to the response: training targets
 * are transformed with {@code scaler.f} before backpropagation, and raw network
 * outputs are mapped back with {@code scaler.inv} in {@link #predict(double[])}.
 */
public class MLP
extends MultilayerPerceptron
implements Regression<double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MLP.class);

    /** Response scaler, or {@code null} when targets are used unscaled. */
    private final Scaler scaler;

    /**
     * Constructs a network without response scaling.
     *
     * @param builders the layer builders, first to last.
     * @throws IllegalArgumentException if no layer builder is given.
     */
    public MLP(LayerBuilder... builders) {
        this((Scaler) null, builders);
    }

    /**
     * Constructs a network with an optional response scaler.
     *
     * @param scaler the response scaler, or {@code null} for none.
     * @param builders the layer builders, first to last.
     * @throws IllegalArgumentException if no layer builder is given.
     */
    public MLP(Scaler scaler, LayerBuilder... builders) {
        super(net(builders));
        this.scaler = scaler;
    }

    /**
     * Builds the layers of the network.
     *
     * <p>Each builder is given the neuron count of the previous builder
     * (0 for the first one). If the last built layer is not an
     * {@link OutputLayer}, a linear, mean-squared-error output layer with one
     * neuron is appended.
     *
     * @param builders the layer builders, first to last.
     * @return the layers of the network.
     * @throws IllegalArgumentException if {@code builders} is null or empty.
     */
    private static Layer[] net(LayerBuilder... builders) {
        if (builders == null || builders.length == 0) {
            throw new IllegalArgumentException("MLP requires at least one layer builder");
        }

        int n = builders.length;
        Layer[] net = new Layer[n];
        int p = 0;
        for (int i = 0; i < n; ++i) {
            net[i] = builders[i].build(p);
            p = builders[i].neurons();
        }

        if (!(net[n - 1] instanceof OutputLayer)) {
            net = Arrays.copyOf(net, n + 1);
            net[n] = new OutputLayer(1, p, OutputFunction.LINEAR, Cost.MEAN_SQUARED_ERROR);
        }
        return net;
    }

    /**
     * Predicts the response of a sample.
     *
     * @param x the input sample.
     * @return the first output neuron's value, mapped back through the
     *         response scaler when one is set.
     */
    @Override
    public double predict(double[] x) {
        this.propagate(this.vector(x), false);
        double y = this.output.output().get(0);
        return this.scaler == null ? y : this.scaler.inv(y);
    }

    /**
     * Returns {@code true}: this model supports incremental training through
     * the {@code update} methods.
     */
    @Override
    public boolean online() {
        return true;
    }

    /**
     * Performs one online training step on a single sample.
     *
     * <p>NOTE(review): {@code propagate(..., true)} and {@code backpropagate(true)}
     * presumably mean "training mode" and "apply the update immediately";
     * confirm against the base class.
     *
     * @param x the input sample.
     * @param y the response value.
     */
    @Override
    public void update(double[] x, double y) {
        this.propagate(this.vector(x), true);
        this.setTarget(y);
        this.backpropagate(true);
        // Advance the step counter inherited from the base class.
        ++this.t;
    }

    /**
     * Performs one mini-batch training step.
     *
     * <p>Every sample is propagated and backpropagated with
     * {@code backpropagate(false)}. The batch update is then applied once through
     * {@code update(x.length)}. NOTE(review): presumably {@code false} only
     * accumulates gradients; confirm against the base class.
     * An empty batch is a no-op.
     *
     * @param x the input samples.
     * @param y the response values, one per sample.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void update(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            // No samples means nothing to learn. Skip instead of calling
            // update(0) with a zero sample count or advancing the step counter.
            return;
        }

        for (int i = 0; i < x.length; ++i) {
            this.propagate(this.vector(x[i]), true);
            this.setTarget(y[i]);
            this.backpropagate(false);
        }
        this.update(x.length);
        ++this.t;
    }

    /**
     * Writes the (optionally scaled) response into the first element of the
     * training target vector.
     *
     * @param y the raw response value.
     */
    private void setTarget(double y) {
        ((Vector) this.target.get()).set(0, this.scaler == null ? y : this.scaler.f(y));
    }

    /**
     * Trains a regression MLP with mini-batch gradient descent.
     *
     * <p>Properties read directly here:
     * <ul>
     * <li>{@code smile.mlp.scaler}: response scaler specification passed to
     *     {@code Scaler.of} (no default).</li>
     * <li>{@code smile.mlp.layers}: layer specification passed to
     *     {@code Layer.of} (default {@code ReLU(100)}).</li>
     * <li>{@code smile.mlp.epochs}: number of passes over the data (default 100).</li>
     * <li>{@code smile.mlp.mini_batch}: mini-batch size (default 32).</li>
     * </ul>
     * The full property set is also passed to {@code setParameters}. Each epoch
     * visits the samples in a fresh random permutation. The final mini-batch of
     * an epoch may be smaller than the configured size.
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @param params the training hyperparameters.
     * @return the trained model.
     * @throws IllegalArgumentException if {@code x} is empty, if {@code x} and
     *         {@code y} differ in length, or if the mini-batch size is not positive.
     * @throws NumberFormatException if the epochs or mini-batch property is not an integer.
     */
    public static MLP fit(double[][] x, double[] y, Properties params) {
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        int epochs = Integer.parseInt(params.getProperty("smile.mlp.epochs", "100"));
        int batch = Integer.parseInt(params.getProperty("smile.mlp.mini_batch", "32"));
        if (batch <= 0) {
            // batch == 0 would never advance the batch loop; negative values
            // would fail on the buffer allocation.
            throw new IllegalArgumentException("Invalid mini-batch size: " + batch);
        }

        int n = x.length;
        int p = x[0].length;
        Scaler scaler = Scaler.of(params.getProperty("smile.mlp.scaler"), y);
        LayerBuilder[] layers = Layer.of(0, p, params.getProperty("smile.mlp.layers", "ReLU(100)"));
        MLP model = new MLP(scaler, layers);
        model.setParameters(params);

        // A mini-batch larger than the data set covers the whole data set, so
        // cap the buffer at n instead of allocating capacity that is never used.
        int capacity = Math.min(batch, n);
        double[][] batchx = new double[capacity][];
        double[] batchy = new double[capacity];
        for (int epoch = 1; epoch <= epochs; ++epoch) {
            logger.info("{} epoch", Strings.ordinal(epoch));
            int[] permutation = MathEx.permutate(n);
            for (int i = 0; i < n; i += capacity) {
                int size = Math.min(capacity, n - i);
                for (int j = 0; j < size; ++j) {
                    int index = permutation[i + j];
                    batchx[j] = x[index];
                    batchy[j] = y[index];
                }
                if (size < capacity) {
                    // Trailing partial batch: hand over trimmed copies so that no
                    // stale entries from the previous batch are included.
                    model.update(Arrays.copyOf(batchx, size), Arrays.copyOf(batchy, size));
                } else {
                    model.update(batchx, batchy);
                }
            }
        }
        return model;
    }
}

