/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.extraction;

import java.util.stream.IntStream;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.transform.Transform;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.DoubleVector;
import smile.data.vector.ValueVector;
import smile.tensor.DenseMatrix;

/**
 * A feature transform that maps numeric input through a projection matrix.
 * Each input is passed through {@link #preprocess}, multiplied by
 * {@link #projection}, and then passed through {@link #postprocess}.
 * The output has one double-valued column per row of the projection matrix.
 */
public class Projection
implements Transform {
    /** The projection matrix; its number of rows is the output dimension. */
    public final DenseMatrix projection;
    /** Schema of the projected output: one double column per matrix row. */
    public final StructType schema;
    /** Names of the input columns read from tuples and data frames (stored as given, not copied). */
    public final String[] columns;

    /**
     * Constructor.
     *
     * @param projection the projection matrix.
     * @param prefix the prefix of the output column names. Columns are named
     *               {@code prefix + 1} through {@code prefix + p}, where p is
     *               the number of rows of {@code projection}.
     * @param columns the names of the input columns to project.
     */
    public Projection(DenseMatrix projection, String prefix, String ... columns) {
        this.projection = projection;
        DataType doubleType = DataTypes.DoubleType;
        int dim = projection.nrow();
        StructField[] outputFields = new StructField[dim];
        for (int k = 0; k < dim; k++) {
            // Output column names are 1-based: prefix1, prefix2, ...
            outputFields[k] = new StructField(prefix + (k + 1), doubleType);
        }
        this.schema = new StructType(outputFields);
        this.columns = columns;
    }

    /**
     * Projects a single tuple.
     *
     * @param x the input tuple; the values of {@link #columns} are extracted into a double array.
     * @return a new tuple with schema {@link #schema} holding the projected values.
     */
    public Tuple apply(Tuple x) {
        double[] input = x.toArray(columns);
        double[] projected = apply(input);
        return Tuple.of(schema, projected);
    }

    /**
     * Projects every row of a data frame.
     *
     * @param data the input data; the values of {@link #columns} are extracted row by row.
     * @return a new data frame with one double column per field of {@link #schema}.
     */
    public DataFrame apply(DataFrame data) {
        double[][] rows = apply(data.toArray(columns));
        int nrow = data.size();
        int ncol = projection.nrow();
        ValueVector[] vectors = new ValueVector[ncol];
        for (int col = 0; col < ncol; col++) {
            // Transpose the row-major result into one column vector per output field.
            double[] values = new double[nrow];
            for (int row = 0; row < nrow; row++) {
                values[row] = rows[row][col];
            }
            vectors[col] = new DoubleVector(schema.field(col), values);
        }
        return new DataFrame(vectors);
    }

    /**
     * Projects a single vector. It computes
     * {@code postprocess(projection * preprocess(x))}.
     *
     * @param x the input vector.
     * @return the projected vector.
     */
    public double[] apply(double[] x) {
        double[] input = preprocess(x);
        return postprocess(projection.mv(input).toArray(new double[0]));
    }

    /**
     * Projects each vector of a batch independently, in order.
     *
     * @param x the input vectors, one per row.
     * @return the projected vectors, in the same order as the input.
     */
    public double[][] apply(double[][] x) {
        double[][] out = new double[x.length][];
        int i = 0;
        for (double[] row : x) {
            out[i++] = apply(row);
        }
        return out;
    }

    /**
     * Hook applied to each input vector before the matrix-vector product.
     * This default returns its argument unchanged; subclasses may override it.
     *
     * @param x the input vector.
     * @return the vector to multiply by {@link #projection}.
     */
    protected double[] preprocess(double[] x) {
        return x;
    }

    /**
     * Hook applied to each vector after the matrix-vector product.
     * This default returns its argument unchanged; subclasses may override it.
     *
     * @param x the raw projected vector.
     * @return the final output vector.
     */
    protected double[] postprocess(double[] x) {
        return x;
    }
}

