/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.runtime.SwitchBootstraps;
import java.util.Arrays;
import java.util.Objects;
import smile.math.MathEx;
import smile.tensor.AbstractTensor;
import smile.tensor.ScalarType;
import smile.tensor.Tensor;

/**
 * A tensor whose elements are stored in a {@link MemorySegment}, with the
 * element type described by a {@link ValueLayout}.
 * <p>
 * Tensors returned by {@link #get(int...)} and {@link #reshape(int...)} are
 * views that share memory with this tensor, so writes through one are visible
 * through the other. Tensors created by the {@code of(...)} factories wrap the
 * given Java array (via {@link MemorySegment#ofArray}) without copying it.
 */
public class JTensor
extends AbstractTensor {
    /** The memory holding the tensor elements. */
    private final transient MemorySegment memory;
    /** The layout of a single element stored in {@link #memory}. */
    private final transient ValueLayout valueLayout;

    /**
     * Constructor.
     * @param memory the memory holding the elements.
     * @param valueLayout the layout of a single element.
     * @param shape the tensor shape.
     */
    private JTensor(MemorySegment memory, ValueLayout valueLayout, int[] shape) {
        super(shape);
        this.memory = memory;
        this.valueLayout = valueLayout;
    }

    /**
     * Returns the layout of a single element.
     * @return the layout of a single element.
     */
    public ValueLayout valueLayout() {
        return this.valueLayout;
    }

    /**
     * Returns the memory segment backing this tensor.
     * @return the memory segment backing this tensor.
     */
    public MemorySegment memory() {
        return this.memory;
    }

    @Override
    public String toString() {
        return "Tensor" + Arrays.toString(this.shape());
    }

    /**
     * Returns the scalar type that corresponds to this tensor's value layout.
     * @return the scalar type of the elements.
     * @throws IllegalStateException if the layout has no corresponding scalar type.
     */
    @Override
    public ScalarType scalarType() {
        return switch (this.valueLayout) {
            case ValueLayout.OfByte layout -> ScalarType.Int8;
            case ValueLayout.OfShort layout -> ScalarType.Int16;
            case ValueLayout.OfInt layout -> ScalarType.Int32;
            case ValueLayout.OfLong layout -> ScalarType.Int64;
            case ValueLayout.OfFloat layout -> ScalarType.Float32;
            case ValueLayout.OfDouble layout -> ScalarType.Float64;
            // NOTE(review): boolean and char have no dedicated ScalarType here and are
            // reported by storage width. char is unsigned while Int16 is presumably
            // signed — confirm this mapping is intended.
            case ValueLayout.OfBoolean layout -> ScalarType.Int8;
            case ValueLayout.OfChar layout -> ScalarType.Int16;
            default -> throw new IllegalStateException("Unsupported ValueLayout: " + String.valueOf(this.valueLayout));
        };
    }

    /**
     * Returns a view of this tensor with a new shape. The view shares memory
     * with this tensor.
     * @param shape the new shape.
     * @return the reshaped view.
     * @throws IllegalArgumentException if the new shape has a different number of elements.
     */
    @Override
    public Tensor reshape(int ... shape) {
        long newLength = MathEx.product(shape);
        long length = this.length();
        if (newLength != length) {
            throw new IllegalArgumentException(String.format("The length of new shape %d != %d", newLength, length));
        }
        return new JTensor(this.memory, this.valueLayout, shape);
    }

    /**
     * Copies the elements of {@code value} into the sub-tensor at {@code index}.
     * @param value the source tensor. It must be a {@code JTensor} with the same value layout.
     * @param index the leading indices that select the target sub-tensor.
     * @return this tensor.
     * @throws UnsupportedOperationException if {@code value} is not a {@code JTensor}
     *         or has a different value layout.
     * @throws IllegalArgumentException if {@code value} has more elements than the
     *         target sub-tensor.
     */
    @Override
    public Tensor set(Tensor value, int ... index) {
        if (value instanceof JTensor B) {
            if (!this.valueLayout.equals(B.valueLayout)) {
                throw new UnsupportedOperationException("set with tensor of different ValueLayout: " + String.valueOf(B.valueLayout));
            }
            long elementSize = this.valueLayout.byteSize();
            long offset = this.offset(index) * elementSize;
            // Copy only B's logical elements; do not rely on the size of its backing segment.
            long count = B.length();
            long capacity = trailingLength(this.shape, index.length);
            if (count > capacity) {
                throw new IllegalArgumentException(String.format("The length of value %d > %d the length of target sub-tensor", count, capacity));
            }
            MemorySegment.copy(B.memory, 0L, this.memory, offset, count * elementSize);
            return this;
        }
        throw new UnsupportedOperationException("Unsupported Tensor type: " + String.valueOf(value.getClass()));
    }

    /**
     * Returns the sub-tensor at the given leading indices as a view that
     * shares memory with this tensor.
     * @param index the leading indices.
     * @return the sub-tensor, whose shape is the trailing dimensions not indexed.
     */
    @Override
    public JTensor get(int ... index) {
        long elementSize = this.valueLayout.byteSize();
        long offset = this.offset(index) * elementSize;
        int[] subShape = Arrays.copyOfRange(this.shape, index.length, this.shape.length);
        // Bound the view to the sub-tensor's own elements, not the rest of this tensor's memory.
        MemorySegment slice = this.memory.asSlice(offset, trailingLength(subShape, 0) * elementSize);
        return new JTensor(slice, this.valueLayout, subShape);
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfBoolean}.
     */
    public Tensor set(boolean value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfBoolean)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfByte}.
     */
    public Tensor set(byte value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfByte)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfShort}.
     */
    public Tensor set(short value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfShort)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfInt}.
     */
    public Tensor set(int value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfInt)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfLong}.
     */
    public Tensor set(long value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfLong)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfFloat}.
     */
    public Tensor set(float value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfFloat)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Sets the element at {@code index}.
     * @param value the new value.
     * @param index the element index.
     * @return this tensor.
     * @throws ClassCastException if the value layout is not {@code OfDouble}.
     */
    public Tensor set(double value, int ... index) {
        long offset = this.offset(index);
        this.memory.setAtIndex((ValueLayout.OfDouble)this.valueLayout, offset, value);
        return this;
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfBoolean}.
     */
    public boolean getBoolean(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfBoolean)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfByte}.
     */
    public byte getByte(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfByte)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfShort}.
     */
    public short getShort(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfShort)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfInt}.
     */
    public int getInt(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfInt)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfLong}.
     */
    public long getLong(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfLong)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfFloat}.
     */
    public float getFloat(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfFloat)this.valueLayout, offset);
    }

    /**
     * Returns the element at {@code index}.
     * @param index the element index.
     * @return the element value.
     * @throws ClassCastException if the value layout is not {@code OfDouble}.
     */
    public double getDouble(int ... index) {
        long offset = this.offset(index);
        return this.memory.getAtIndex((ValueLayout.OfDouble)this.valueLayout, offset);
    }

    /**
     * Returns the number of elements spanned by the dimensions
     * {@code dims[from..]}. An empty range yields 1 (a scalar).
     */
    private static long trailingLength(int[] dims, int from) {
        long n = 1;
        for (int i = from; i < dims.length; i++) {
            n *= dims[i];
        }
        return n;
    }

    /**
     * Validates that {@code shape} describes exactly {@code arrayLength} elements.
     * @throws IllegalArgumentException if the element counts differ.
     */
    private static void checkLength(int arrayLength, int[] shape) {
        long length = MathEx.product(shape);
        if (length != (long)arrayLength) {
            throw new IllegalArgumentException(String.format("The length of shape %d != %d the length of array", length, arrayLength));
        }
    }

    /**
     * Wraps a byte array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(byte[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_BYTE, shape);
    }

    /**
     * Wraps a short array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(short[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_SHORT, shape);
    }

    /**
     * Wraps an int array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(int[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_INT, shape);
    }

    /**
     * Wraps a long array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(long[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_LONG, shape);
    }

    /**
     * Wraps a float array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(float[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_FLOAT, shape);
    }

    /**
     * Wraps a double array as a tensor without copying it.
     * @param data the tensor elements.
     * @param shape the tensor shape.
     * @return the tensor.
     * @throws IllegalArgumentException if the shape does not match the array length.
     */
    static JTensor of(double[] data, int ... shape) {
        checkLength(data.length, shape);
        return new JTensor(MemorySegment.ofArray(data), ValueLayout.JAVA_DOUBLE, shape);
    }
}

