package smile.feature;

import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.DoubleVector;
import smile.math.MathEx;

/**
 * Feature transform that divides each numeric column by a per-column scaling factor.
 *
 * <p>When built with {@link #fit(DataFrame)}, each factor is the maximum absolute value
 * of that column in the training data. For finite training data, every scaled numeric
 * value therefore lies in [-1, 1].
 *
 * <p>Factors that {@code MathEx.isZero} reports as zero are replaced by 1.0 in the
 * constructor, so those columns are passed through unscaled instead of being divided by
 * zero. Non-numeric columns never get a factor from {@code fit}; their factor is therefore
 * 1.0, and the {@code Tuple}/{@code DataFrame} transforms pass them through untouched.
 */
public class MaxAbsScaler implements FeatureTransform {
    private static final long serialVersionUID = 2;

    /** Schema of the data this scaler accepts; inputs to the transforms must match it. */
    protected StructType schema;

    /** Per-column divisors, one per schema field; never (near-)zero after construction. */
    private final double[] scale;

    /**
     * Creates a scaler from explicit per-column scaling factors.
     *
     * @param structType the schema of the data to be transformed.
     * @param dArr the scaling factor of each column, in schema order. The array is copied,
     *             so later changes to it do not affect this scaler, and this constructor
     *             never modifies it.
     * @throws IllegalArgumentException if {@code dArr.length != structType.length()}.
     */
    public MaxAbsScaler(StructType structType, double[] dArr) {
        if (structType.length() != dArr.length) {
            throw new IllegalArgumentException("Schema and scaling factor size don't match");
        }
        this.schema = structType;
        // Defensive copy: normalizing zeros must not leak back into the caller's array.
        this.scale = dArr.clone();
        for (int i = 0; i < scale.length; i++) {
            // Avoid division by (near-)zero: such columns are left unscaled.
            if (MathEx.isZero(scale[i])) {
                scale[i] = 1.0d;
            }
        }
    }

    /**
     * Learns the maximum absolute value of every numeric column.
     *
     * @param dataFrame the training data.
     * @return the fitted scaler. Non-numeric columns get factor 1.0, i.e. they are unscaled.
     * @throws IllegalArgumentException if the data frame is empty.
     */
    public static MaxAbsScaler fit(DataFrame dataFrame) {
        if (dataFrame.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        StructType schema = dataFrame.schema();
        double[] dArr = new double[schema.length()];
        for (int i = 0; i < dArr.length; i++) {
            if (schema.field(i).isNumeric()) {
                // max() is present because the frame is non-empty (checked above).
                // NOTE(review): a NaN in the column makes max() return NaN, which is not
                // treated as zero, so the whole column becomes NaN after transform.
                dArr[i] = ((DoubleStream) dataFrame.doubleVector(i).stream()).map(Math::abs).max().getAsDouble();
            }
            // Non-numeric columns keep 0.0 here; the constructor turns that into 1.0.
        }
        return new MaxAbsScaler(schema, dArr);
    }

    /**
     * Learns the maximum absolute value of every column of a numeric matrix.
     *
     * @param dArr the training data, one row per sample.
     * @return the fitted scaler.
     * @throws IllegalArgumentException if the resulting data frame is empty.
     */
    public static MaxAbsScaler fit(double[][] dArr) {
        // Wraps the matrix in a DataFrame without explicit column names.
        return fit(DataFrame.of(dArr, new String[0]));
    }

    /**
     * Divides a value by the scaling factor of column {@code i}.
     *
     * <p>This is an internal helper. It is kept public only for binary compatibility.
     *
     * @param d the raw value.
     * @param i the column index.
     * @return {@code d / scale[i]}.
     */
    public double scale(double d, int i) {
        return d / this.scale[i];
    }

    /**
     * Scales a raw feature vector element-wise.
     *
     * <p>Every element is divided by its column's factor, whether or not the column is
     * numeric. An input shorter than the schema scales only the leading columns.
     *
     * @param dArr the feature vector, in schema column order.
     * @return a new array with the scaled values.
     * @throws IllegalArgumentException if {@code dArr} has more elements than the schema
     *         has columns.
     */
    @Override // smile.feature.FeatureTransform
    public double[] transform(double[] dArr) {
        if (dArr.length > scale.length) {
            throw new IllegalArgumentException(String.format("Invalid input vector size %d, expected at most %d", dArr.length, scale.length));
        }
        double[] dArr2 = new double[dArr.length];
        for (int i = 0; i < dArr2.length; i++) {
            dArr2[i] = scale(dArr[i], i);
        }
        return dArr2;
    }

    /**
     * Returns a lazy view of the tuple with its numeric fields scaled.
     *
     * <p>Values are computed on each {@code get} call. Non-numeric fields are returned as-is.
     *
     * @param tuple the input tuple; its schema must equal this scaler's schema.
     * @return the scaled view.
     * @throws IllegalArgumentException if the tuple schema differs.
     */
    @Override // smile.feature.FeatureTransform
    public Tuple transform(final Tuple tuple) {
        if (!this.schema.equals(tuple.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", tuple.schema(), this.schema));
        }
        return new AbstractTuple() { // from class: smile.feature.MaxAbsScaler.1
            @Override
            public Object get(int i) {
                return MaxAbsScaler.this.schema.field(i).isNumeric() ? Double.valueOf(MaxAbsScaler.this.scale(tuple.getDouble(i), i)) : tuple.get(i);
            }

            @Override
            public StructType schema() {
                return MaxAbsScaler.this.schema;
            }
        };
    }

    /**
     * Scales every numeric column of a data frame.
     *
     * <p>Numeric columns are rebuilt as double vectors of scaled values. Other columns are
     * reused unchanged.
     *
     * @param dataFrame the input data; its schema must equal this scaler's schema.
     * @return a new data frame with the scaled columns.
     * @throws IllegalArgumentException if the data frame schema differs.
     */
    @Override // smile.feature.FeatureTransform
    public DataFrame transform(DataFrame dataFrame) {
        if (!this.schema.equals(dataFrame.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", dataFrame.schema(), this.schema));
        }
        BaseVector[] baseVectorArr = new BaseVector[this.schema.length()];
        for (int i = 0; i < this.scale.length; i++) {
            StructField field = this.schema.field(i);
            if (field.isNumeric()) {
                // Effectively-final copy of the loop index for the lambda.
                final int column = i;
                baseVectorArr[i] = DoubleVector.of(field, dataFrame.stream().mapToDouble(tuple -> scale(tuple.getDouble(column), column)));
            } else {
                baseVectorArr[i] = dataFrame.column(i);
            }
        }
        return DataFrame.of(baseVectorArr);
    }

    /**
     * Lists each field name with its scaling factor.
     *
     * @return e.g. {@code MaxAbsScaler(a[2.0000],b[1.0000])}.
     */
    @Override
    public String toString() {
        return (String) IntStream.range(0, this.scale.length).mapToObj(i -> {
            return String.format("%s[%.4f]", this.schema.field(i).name, Double.valueOf(this.scale[i]));
        }).collect(Collectors.joining(",", "MaxAbsScaler(", ")"));
    }
}
