/*
 * Decompiled with CFR 0.152.
 */
package smile.feature;

import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.DoubleVector;
import smile.feature.FeatureTransform;
import smile.math.MathEx;

/**
 * Standardizes numeric features by subtracting a per-column location {@code mu}
 * and dividing by a per-column scale {@code std}, i.e. {@code y = (x - mu) / std}.
 * <p>
 * Scale entries that are effectively zero (per {@code MathEx.isZero}) are replaced
 * with 1.0, so such columns are only centered instead of producing infinities.
 * The {@link Tuple} and {@link DataFrame} transforms pass non-numeric columns through
 * unchanged.
 */
public class Standardizer
implements FeatureTransform {
    private static final long serialVersionUID = 2L;
    /** The schema of the data this transform was built for. */
    StructType schema;
    /** Per-column location (mean). */
    double[] mu;
    /** Per-column scale (standard deviation); never zero after construction. */
    double[] std;

    /**
     * Constructor.
     *
     * @param schema the schema of the data the transform applies to.
     * @param mu the per-column location. The array is copied.
     * @param std the per-column scale. The array is copied, and entries that are
     *            effectively zero are replaced with 1.0 in the copy.
     * @throws IllegalArgumentException if the sizes of schema, mu and std differ.
     */
    public Standardizer(StructType schema, double[] mu, double[] std) {
        if (schema.length() != mu.length || mu.length != std.length) {
            throw new IllegalArgumentException("Schema and scaling factor size don't match");
        }
        // Work on copies so the caller's arrays are neither mutated by the zero
        // replacement below nor able to alter this transform's state afterwards.
        double[] scale = std.clone();
        for (int i = 0; i < scale.length; ++i) {
            if (MathEx.isZero(scale[i])) {
                scale[i] = 1.0;
            }
        }
        this.schema = schema;
        this.mu = mu.clone();
        this.std = scale;
    }

    /**
     * Learns the per-column mean and (population) standard deviation of the
     * numeric columns of a data frame. Non-numeric columns get location 0 and
     * scale 1, which leaves them unchanged.
     *
     * @param data the training data.
     * @return the fitted standardizer.
     * @throws IllegalArgumentException if the data frame is empty.
     */
    public static Standardizer fit(DataFrame data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        StructType schema = data.schema();
        int p = schema.length();
        double[] mu = new double[p];
        double[] std = new double[p];
        int n = data.nrows();
        for (int i = 0; i < p; ++i) {
            if (!schema.field(i).isNumeric()) {
                // Left as mu = 0, std = 0; the constructor turns the 0 scale into 1.0.
                continue;
            }
            final int col = i;
            final double mean = data.stream().mapToDouble(t -> t.getDouble(col)).sum() / n;
            // Two-pass variance: sum of squared deviations from the mean. The one-pass
            // E[x^2] - mean^2 form cancels catastrophically and can turn slightly
            // negative for (near-)constant columns, making sqrt return NaN.
            double squaredDeviations = data.stream().mapToDouble(t -> {
                double d = t.getDouble(col) - mean;
                return d * d;
            }).sum();
            mu[i] = mean;
            std[i] = Math.sqrt(squaredDeviations / n);
        }
        // Zero scales are replaced with 1.0 by the constructor.
        return new Standardizer(schema, mu, std);
    }

    /**
     * Learns the per-column mean and standard deviation of a numeric matrix.
     *
     * @param data the training data, one row per sample.
     * @return the fitted standardizer.
     */
    public static Standardizer fit(double[][] data) {
        return Standardizer.fit(DataFrame.of(data, new String[0]));
    }

    /** Standardizes a single value of column {@code i}. */
    private double scale(double x, int i) {
        return (x - this.mu[i]) / this.std[i];
    }

    /**
     * Standardizes a feature vector. Every element is scaled, regardless of the
     * column type in the schema.
     * <p>
     * NOTE(review): {@code x} is expected to have one entry per schema field; a
     * longer array throws ArrayIndexOutOfBoundsException and a shorter one is
     * scaled element-wise without complaint.
     *
     * @param x the feature vector.
     * @return a new array holding the standardized values.
     */
    @Override
    public double[] transform(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < y.length; ++i) {
            y[i] = this.scale(x[i], i);
        }
        return y;
    }

    /**
     * Returns a lazy view of {@code x} in which numeric fields are standardized
     * on access and non-numeric fields are returned as is.
     *
     * @param x the input tuple; must have exactly this transform's schema.
     * @return the standardized tuple view.
     * @throws IllegalArgumentException if the tuple's schema differs.
     */
    @Override
    public Tuple transform(final Tuple x) {
        if (!this.schema.equals(x.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", x.schema(), this.schema));
        }
        return new AbstractTuple() {
            @Override
            public Object get(int i) {
                if (Standardizer.this.schema.field(i).isNumeric()) {
                    return Standardizer.this.scale(x.getDouble(i), i);
                }
                return x.get(i);
            }

            @Override
            public StructType schema() {
                return Standardizer.this.schema;
            }
        };
    }

    /**
     * Standardizes the numeric columns of a data frame. Non-numeric columns are
     * reused as is.
     *
     * @param data the input data; must have exactly this transform's schema.
     * @return a new data frame with the standardized columns.
     * @throws IllegalArgumentException if the data frame's schema differs.
     */
    @Override
    public DataFrame transform(DataFrame data) {
        if (!this.schema.equals(data.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", data.schema(), this.schema));
        }
        BaseVector[] vectors = new BaseVector[this.schema.length()];
        for (int i = 0; i < this.mu.length; ++i) {
            StructField field = this.schema.field(i);
            if (field.isNumeric()) {
                final int col = i;
                DoubleStream stream = data.stream().mapToDouble(t -> this.scale(t.getDouble(col), col));
                vectors[i] = DoubleVector.of(field, stream);
            } else {
                vectors[i] = data.column(i);
            }
        }
        return DataFrame.of(vectors);
    }

    @Override
    public String toString() {
        return IntStream.range(0, this.mu.length)
                .mapToObj(i -> String.format("%s[%.4f, %.4f]", this.schema.field(i).name, this.mu[i], this.std[i]))
                .collect(Collectors.joining(",", "Standardizer(", ")"));
    }
}

