package smile.data;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.type.DataType;
import smile.data.type.ObjectType;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.BooleanVector;
import smile.data.vector.ByteVector;
import smile.data.vector.CharVector;
import smile.data.vector.DoubleVector;
import smile.data.vector.FloatVector;
import smile.data.vector.IntVector;
import smile.data.vector.LongVector;
import smile.data.vector.ShortVector;
import smile.data.vector.StringVector;
import smile.data.vector.Vector;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/data/IndexDataFrame.class */
public class IndexDataFrame implements DataFrame {
    private static final Logger logger = LoggerFactory.getLogger(IndexDataFrame.class);
    private DataFrame df;
    private int[] index;

    public IndexDataFrame(DataFrame dataFrame, int[] iArr) {
        this.df = dataFrame;
        this.index = iArr;
    }

    @Override // smile.data.DataFrame
    public StructType schema() {
        return this.df.schema();
    }

    public String toString() {
        return toString(10, true);
    }

    @Override // java.lang.Iterable
    public Iterator<BaseVector> iterator() {
        return this.df.iterator();
    }

    @Override // smile.data.DataFrame
    public int columnIndex(String str) {
        return this.df.columnIndex(str);
    }

    @Override // smile.data.Dataset
    public int size() {
        return this.index.length;
    }

    @Override // smile.data.DataFrame
    public int ncols() {
        return this.df.ncols();
    }

    @Override // smile.data.DataFrame
    public Object get(int i, int i2) {
        return this.df.get(this.index[i], i2);
    }

    @Override // smile.data.Dataset
    public Stream<Tuple> stream() {
        return Arrays.stream(this.index).mapToObj(i -> {
            return this.df.get(i);
        });
    }

    @Override // smile.data.DataFrame
    public BaseVector column(int i) {
        return this.df.column(i).get2(this.index);
    }

    @Override // smile.data.DataFrame
    public <T> Vector<T> vector(int i) {
        return this.df.vector(i);
    }

    @Override // smile.data.DataFrame
    public BooleanVector booleanVector(int i) {
        return this.df.booleanVector(i);
    }

    @Override // smile.data.DataFrame
    public CharVector charVector(int i) {
        return this.df.charVector(i);
    }

    @Override // smile.data.DataFrame
    public ByteVector byteVector(int i) {
        return this.df.byteVector(i);
    }

    @Override // smile.data.DataFrame
    public ShortVector shortVector(int i) {
        return this.df.shortVector(i);
    }

    @Override // smile.data.DataFrame
    public IntVector intVector(int i) {
        return this.df.intVector(i);
    }

    @Override // smile.data.DataFrame
    public LongVector longVector(int i) {
        return this.df.longVector(i);
    }

    @Override // smile.data.DataFrame
    public FloatVector floatVector(int i) {
        return this.df.floatVector(i);
    }

    @Override // smile.data.DataFrame
    public DoubleVector doubleVector(int i) {
        return this.df.doubleVector(i);
    }

    @Override // smile.data.DataFrame
    public StringVector stringVector(int i) {
        return this.df.stringVector(i);
    }

    @Override // smile.data.DataFrame
    public DataFrame select(int... iArr) {
        return new IndexDataFrame(this.df.select(iArr), this.index);
    }

    @Override // smile.data.DataFrame
    public DataFrame drop(int... iArr) {
        return new IndexDataFrame(this.df.drop(iArr), this.index);
    }

    private DataFrame rebase() {
        return DataFrame.of((List<Tuple>) stream().collect(Collectors.toList()));
    }

    @Override // smile.data.DataFrame
    public DataFrame merge(DataFrame... dataFrameArr) {
        for (DataFrame dataFrame : dataFrameArr) {
            if (dataFrame.size() != size()) {
                throw new IllegalArgumentException("Merge data frames with different size: " + size() + " vs " + dataFrame.size());
            }
        }
        return rebase().merge(dataFrameArr);
    }

    @Override // smile.data.DataFrame
    public DataFrame merge(BaseVector... baseVectorArr) {
        for (BaseVector baseVector : baseVectorArr) {
            if (baseVector.size() != size()) {
                throw new IllegalArgumentException("Merge data frames with different size: " + size() + " vs " + baseVector.size());
            }
        }
        return rebase().merge(baseVectorArr);
    }

    @Override // smile.data.DataFrame
    public DataFrame union(DataFrame... dataFrameArr) {
        for (DataFrame dataFrame : dataFrameArr) {
            if (!schema().equals(dataFrame.schema())) {
                throw new IllegalArgumentException("Union data frames with different schema: " + schema() + " vs " + dataFrame.schema());
            }
        }
        return rebase().union(dataFrameArr);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // smile.data.Dataset
    public Tuple get(int i) {
        return this.df.get(this.index[i]);
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // smile.data.DataFrame
    public double[][] toArray() {
        int nrows = nrows();
        int ncols = ncols();
        DataType[] types = types();
        double[][] dArr = new double[nrows][ncols];
        for (int i = 0; i < ncols; i++) {
            DataType dataType = types[i];
            switch (dataType.id()) {
                case Double:
                    DoubleVector doubleVector = doubleVector(i);
                    for (int i2 = 0; i2 < nrows; i2++) {
                        dArr[i2][i] = doubleVector.getDouble(this.index[i2]);
                    }
                    break;
                case Integer:
                    IntVector intVector = intVector(i);
                    for (int i3 = 0; i3 < nrows; i3++) {
                        dArr[i3][i] = intVector.getInt(this.index[i3]);
                    }
                    break;
                case Float:
                    FloatVector floatVector = floatVector(i);
                    for (int i4 = 0; i4 < nrows; i4++) {
                        dArr[i4][i] = floatVector.getFloat(this.index[i4]);
                    }
                    break;
                case Long:
                    LongVector longVector = longVector(i);
                    for (int i5 = 0; i5 < nrows; i5++) {
                        dArr[i5][i] = longVector.getLong(this.index[i5]);
                    }
                    break;
                case Boolean:
                    BooleanVector booleanVector = booleanVector(i);
                    for (int i6 = 0; i6 < nrows; i6++) {
                        dArr[i6][i] = booleanVector.getDouble(this.index[i6]);
                    }
                    break;
                case Byte:
                    ByteVector byteVector = byteVector(i);
                    for (int i7 = 0; i7 < nrows; i7++) {
                        dArr[i7][i] = byteVector.getByte(this.index[i7]);
                    }
                    break;
                case Short:
                    ShortVector shortVector = shortVector(i);
                    for (int i8 = 0; i8 < nrows; i8++) {
                        dArr[i8][i] = shortVector.getShort(this.index[i8]);
                    }
                    break;
                case Char:
                    CharVector charVector = charVector(i);
                    for (int i9 = 0; i9 < nrows; i9++) {
                        dArr[i9][i] = charVector.getChar(this.index[i9]);
                    }
                    break;
                case String:
                    StringVector stringVector = stringVector(i);
                    for (int i10 = 0; i10 < nrows; i10++) {
                        String str = stringVector.get(this.index[i10]);
                        dArr[i10][i] = str == null ? Double.NaN : Double.valueOf(str).doubleValue();
                    }
                    break;
                case Object:
                    Class objectClass = ((ObjectType) dataType).getObjectClass();
                    if (objectClass != Boolean.class) {
                        if (!Number.class.isAssignableFrom(objectClass)) {
                            throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", dataType));
                        }
                        Vector vector = vector(i);
                        for (int i11 = 0; i11 < nrows; i11++) {
                            dArr[i11][i] = vector.getDouble(this.index[i11]);
                        }
                        break;
                    } else {
                        Vector vector2 = vector(i);
                        for (int i12 = 0; i12 < nrows; i12++) {
                            Boolean bool = (Boolean) vector2.get(this.index[i12]);
                            if (bool != null) {
                                dArr[i12][i] = bool.booleanValue() ? 1.0d : 0.0d;
                            } else {
                                dArr[i12][i] = Double.NaN;
                            }
                        }
                        break;
                    }
                default:
                    throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", dataType));
            }
        }
        return dArr;
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // smile.data.DataFrame
    public DenseMatrix toMatrix() {
        int nrows = nrows();
        int ncols = ncols();
        DataType[] types = types();
        DenseMatrix of = Matrix.of(nrows, ncols, 0.0d);
        for (int i = 0; i < ncols; i++) {
            DataType dataType = types[i];
            switch (dataType.id()) {
                case Double:
                    DoubleVector doubleVector = doubleVector(i);
                    for (int i2 = 0; i2 < nrows; i2++) {
                        of.set(i2, i, doubleVector.getDouble(this.index[i2]));
                    }
                    break;
                case Integer:
                    IntVector intVector = intVector(i);
                    for (int i3 = 0; i3 < nrows; i3++) {
                        of.set(i3, i, intVector.getInt(this.index[i3]));
                    }
                    break;
                case Float:
                    FloatVector floatVector = floatVector(i);
                    for (int i4 = 0; i4 < nrows; i4++) {
                        of.set(i4, i, floatVector.getFloat(this.index[i4]));
                    }
                    break;
                case Long:
                    LongVector longVector = longVector(i);
                    for (int i5 = 0; i5 < nrows; i5++) {
                        of.set(i5, i, longVector.getLong(this.index[i5]));
                    }
                    break;
                case Boolean:
                    BooleanVector booleanVector = booleanVector(i);
                    for (int i6 = 0; i6 < nrows; i6++) {
                        of.set(i6, i, booleanVector.getDouble(this.index[i6]));
                    }
                    break;
                case Byte:
                    ByteVector byteVector = byteVector(i);
                    for (int i7 = 0; i7 < nrows; i7++) {
                        of.set(i7, i, byteVector.getByte(this.index[i7]));
                    }
                    break;
                case Short:
                    ShortVector shortVector = shortVector(i);
                    for (int i8 = 0; i8 < nrows; i8++) {
                        of.set(i8, i, shortVector.getShort(this.index[i8]));
                    }
                    break;
                case Char:
                    CharVector charVector = charVector(i);
                    for (int i9 = 0; i9 < nrows; i9++) {
                        of.set(i9, i, charVector.getChar(this.index[i9]));
                    }
                    break;
                case String:
                    StringVector stringVector = stringVector(i);
                    for (int i10 = 0; i10 < nrows; i10++) {
                        String str = stringVector.get(this.index[i10]);
                        of.set(i10, i, str == null ? Double.NaN : Double.valueOf(str).doubleValue());
                    }
                    break;
                case Object:
                    Class objectClass = ((ObjectType) dataType).getObjectClass();
                    if (objectClass != Boolean.class) {
                        if (!Number.class.isAssignableFrom(objectClass)) {
                            throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", dataType));
                        }
                        Vector vector = vector(i);
                        for (int i11 = 0; i11 < nrows; i11++) {
                            of.set(i11, i, vector.getDouble(this.index[i11]));
                        }
                        break;
                    } else {
                        Vector vector2 = vector(i);
                        for (int i12 = 0; i12 < nrows; i12++) {
                            Boolean bool = (Boolean) vector2.get(this.index[i12]);
                            if (bool != null) {
                                of.set(i12, i, bool.booleanValue() ? 1.0d : 0.0d);
                            } else {
                                of.set(i12, i, Double.NaN);
                            }
                        }
                        break;
                    }
                default:
                    throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", dataType));
            }
        }
        return of;
    }
}
