package smile.feature.extraction;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.CategoricalMeasure;
import smile.data.type.StructField;
import smile.data.type.StructType;

/* loaded from: input_file:smile/feature/extraction/BinaryEncoder.class */
/**
 * Encodes selected categorical columns of a tuple as integer feature indices.
 *
 * <p>Each selected column is given an offset equal to the total size of the categorical
 * measures of the columns before it. Encoding a tuple returns, for each column, the
 * column's integer value plus that offset. The index ranges of different columns
 * therefore do not overlap, provided each value lies in {@code [0, measure.size())}.
 * NOTE(review): that value range is assumed and is not checked here.
 */
public class BinaryEncoder implements Function<Tuple, int[]> {
    /** The schema of the data to encode. */
    private final StructType schema;
    /** The names of the categorical columns to encode, in output order. */
    private final String[] columns;
    /** The starting offset of each column's index range; {@code base[0]} is 0. */
    private final int[] base;

    /**
     * Constructor.
     *
     * @param structType the schema of the data.
     * @param strArr the names of the categorical columns to encode. If {@code null} or
     *     empty, every field of the schema whose measure is a {@link CategoricalMeasure}
     *     is encoded.
     * @throws NullPointerException if {@code structType} is null.
     * @throws IllegalArgumentException if a requested column is not categorical.
     */
    public BinaryEncoder(StructType structType, String... strArr) {
        this.schema = Objects.requireNonNull(structType, "structType");
        this.columns = (strArr == null || strArr.length == 0)
                ? Arrays.stream(structType.fields())
                        .filter(f -> f.measure instanceof CategoricalMeasure)
                        .map(f -> f.name)
                        .toArray(String[]::new)
                // Defensive copy: a caller may pass an explicit array and mutate it later.
                : strArr.clone();

        this.base = new int[columns.length];
        for (int i = 0; i < columns.length; i++) {
            StructField field = structType.field(columns[i]);
            if (!(field.measure instanceof CategoricalMeasure)) {
                throw new IllegalArgumentException("Non-categorical attribute: " + field);
            }
            // The next column's range starts right after this column's categories.
            if (i + 1 < base.length) {
                base[i + 1] = base[i] + field.measure.size();
            }
        }
    }

    /**
     * Encodes the selected columns of a tuple.
     *
     * @param tuple the tuple to encode.
     * @return an array with one entry per encoded column: the column's integer value
     *     (via {@code tuple.getInt}) plus that column's offset.
     */
    @Override // java.util.function.Function
    public int[] apply(Tuple tuple) {
        int[] x = new int[columns.length];
        for (int i = 0; i < x.length; i++) {
            x[i] = tuple.getInt(columns[i]) + base[i];
        }
        return x;
    }

    /**
     * Encodes every tuple of a data frame.
     *
     * @param dataFrame the data to encode.
     * @return one encoded array (as returned by {@link #apply(Tuple)}) per tuple produced
     *     by {@code dataFrame.stream()}.
     */
    public int[][] apply(DataFrame dataFrame) {
        // The generator must build an int[][]; an int[] generator does not type-check.
        return dataFrame.stream().map(this::apply).toArray(int[][]::new);
    }
}
