package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.IOException;
import java.lang.Number;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for n-dimensional tensors whose shape may be column-sparse
 * (see {@link TensorShape#sparsifyColumns}).
 *
 * <p>Subclasses supply storage ({@link #getMemorySegment()}), scalar element access
 * ({@link #get}/{@link #set}), SIMD access ({@link #getVector}/{@link #intoTensor}) and the
 * factory hooks used for slicing and reshaping ({@link #make(TensorShape)},
 * {@link #make(int, int, TensorShape, boolean)}).
 *
 * @param <V> the {@link Vector} type produced by {@link #getVector}
 * @param <T> the element type of the {@link VectorSpecies} used for vector access
 */
public abstract class AbstractTensor<V extends Vector<?>, T extends Number> implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(AbstractTensor.class);

    /** Logical shape of this tensor. */
    protected final TensorShape shape;
    /** Element data type of this tensor. */
    protected final DType dType;
    /**
     * Cache of single-index slices, one slot per entry of the first dimension; {@code null} when
     * slice caching was not requested at construction.
     */
    protected final AbstractTensor[] sliceCache;
    /**
     * {@code shape.getOffset(sparseRowOffset + 1, sparseColumnOffset)} for 2-D tensors with more
     * than one row (presumably the row stride), otherwise 0.
     */
    private final int stride;
    /** Cache this tensor was obtained from, if any; the tensor is released back to it on {@link #close()}. */
    private volatile TensorCache originCache = null;

    /**
     * @param dType element type
     * @param tensorShape shape; must be non-null with at least one dimension
     * @param cacheSlices whether single-index slices returned by {@link #slice} are memoized
     */
    public AbstractTensor(DType dType, TensorShape tensorShape, boolean cacheSlices) {
        Preconditions.checkArgument(
                tensorShape != null && tensorShape.dims() > 0,
                "Tensor shape must be non-null with at least one dimension");
        this.dType = dType;
        this.shape = tensorShape;
        this.sliceCache = cacheSlices ? new AbstractTensor[tensorShape.first()] : null;
        // Only 2-D tensors with more than one row get a non-zero stride.
        boolean multiRowMatrix = tensorShape.first() > 1 && dims() == 2;
        this.stride = multiRowMatrix
                ? getOffset(tensorShape.sparseRowOffset() + 1, tensorShape.sparseColumnOffset())
                : 0;
    }

    /**
     * Creates a new tensor of the given type and shape.
     *
     * @throws RuntimeException if {@code dType} is not F32, BF16 or I8
     */
    public static AbstractTensor make(DType dType, TensorShape tensorShape) {
        return switch (dType) {
            case F32 -> new FloatBufferTensor(tensorShape);
            case BF16 -> new BFloat16BufferTensor(tensorShape);
            case I8 -> new Q8ByteBufferTensor(tensorShape);
            default -> throw new RuntimeException("Unsupported tensor type: " + dType);
        };
    }

    /** Creates a new tensor of this tensor's concrete type with the given shape. */
    protected abstract AbstractTensor make(TensorShape tensorShape);

    /**
     * Creates a tensor of this concrete type covering {@code length} elements starting at element
     * {@code offset} of this tensor's storage (presumably a view sharing storage; confirm in
     * subclasses).
     */
    protected abstract AbstractTensor make(int offset, int length, TensorShape tensorShape, boolean cacheSlices);

    /** Returns a tensor with the same dtype and shape, obtained from {@code TensorCache.instance}. */
    public AbstractTensor copyShape() {
        return TensorCache.instance.get(this.dType, this.shape);
    }

    /** Number of dimensions. */
    public final int dims() {
        return this.shape.dims();
    }

    /** Shape of this tensor. */
    public final TensorShape shape() {
        return this.shape;
    }

    /** Total element count as reported by the shape. */
    public final long size() {
        return this.shape.size();
    }

    /** Reads the element at the given indices. */
    public abstract float get(int... indices);

    /** Writes {@code value} at the given indices. */
    public abstract void set(float value, int... indices);

    /** Equivalent to {@code slice(false, indices)}. */
    public AbstractTensor slice(int... indices) {
        return slice(false, indices);
    }

    /**
     * Returns the sub-tensor addressed by fixing the leading {@code indices.length} dimensions.
     *
     * <p>Single-index slices are memoized in {@link #sliceCache} when caching is enabled.
     *
     * @param cacheInnerSlices passed to {@link #make(int, int, TensorShape, boolean)} for the result
     * @param indices leading indices; there must be fewer of them than dimensions
     */
    public AbstractTensor slice(boolean cacheInnerSlices, int... indices) {
        Preconditions.checkArgument(indices.length < this.shape.dims(), "Too many dimensions specified for tensor");
        try {
            boolean cacheable = indices.length == 1 && this.sliceCache != null;
            if (cacheable && this.sliceCache[indices[0]] != null) {
                return this.sliceCache[indices[0]];
            }
            TensorShape sliceShape = this.shape.slice(indices.length);
            int offset = 0;
            if (indices.length == 1 && this.shape.dims() == 2) {
                // Fast path: selecting a single row of a matrix.
                offset = this.shape.sparseColumnLength() * indices[0];
            } else {
                for (int d = 0; d < indices.length; d++) {
                    // Elements spanned by one step along dimension d: the (sparse) row length
                    // times every dimension between d and the last one.
                    int elementsPerStep = this.shape.sparseColumnLength();
                    for (int inner = this.shape.dims() - 2; inner > d; inner--) {
                        elementsPerStep *= this.shape.dim(inner);
                    }
                    offset += indices[d] * elementsPerStep;
                }
            }
            AbstractTensor result = make(offset, (int) sliceShape.size(), sliceShape, cacheInnerSlices);
            if (cacheable) {
                this.sliceCache[indices[0]] = result;
            }
            return result;
        } catch (Throwable th) {
            logger.warn("Dims = {}", Arrays.toString(indices), th);
            throw th;
        }
    }

    /**
     * Returns a copy restricted to {@code columnLength} columns of the last dimension, starting at
     * {@code columnOffset}. Returns {@code this} if the tensor is already sparse or if
     * {@code columnLength} equals the full last dimension.
     */
    public AbstractTensor<V, T> sparsify(int columnOffset, int columnLength) {
        if (this.shape.isSparse() || columnLength == this.shape.last()) {
            return this;
        }
        AbstractTensor<V, T> sparse = make(this.shape.sparsifyColumns(columnOffset, columnLength));
        int lastDim = this.shape.last();
        int[] cursor = new int[this.shape.dims()];
        do {
            try {
                cursor[cursor.length - 1] = columnOffset;
                sparse.copyFrom(this, getOffset(cursor), sparse.getOffset(cursor), columnLength);
                // Park the last index at its maximum so iterate() moves on to the next row.
                cursor[cursor.length - 1] = lastDim - 1;
            } catch (Throwable th) {
                logger.warn("Cursor = {}", Arrays.toString(cursor), th);
                throw th;
            }
        } while (iterate(cursor));
        return sparse;
    }

    /**
     * Splits this tensor into {@code numChunks} tensors of equal size along dimension {@code dim}.
     *
     * <p>NOTE(review): chunk {@code k} starts at element {@code k * chunkShape.size()}, i.e. chunks
     * are contiguous ranges of storage. That is a true split only along dimension 0 (or when all
     * preceding dimensions are 1); confirm before using other dimensions.
     *
     * @throws IllegalArgumentException if {@code numChunks <= 0}
     * @throws IllegalStateException if the dimension is not divisible by {@code numChunks}
     */
    public AbstractTensor[] split(int numChunks, int dim) {
        Preconditions.checkArgument(numChunks > 0, "numChunks must be positive, got %s", numChunks);
        int chunkLength = this.shape.dim(dim) / numChunks;
        if (chunkLength * numChunks != this.shape.dim(dim)) {
            throw new IllegalStateException("Chunks must be of equal size");
        }
        TensorShape chunkShape = this.shape.setDimValue(dim, chunkLength);
        AbstractTensor[] chunks = new AbstractTensor[numChunks];
        for (int k = 0; k < numChunks; k++) {
            chunks[k] = make(Ints.checkedCast(k * chunkShape.size()), Ints.checkedCast(chunkShape.size()), chunkShape, true);
        }
        return chunks;
    }

    /**
     * Advances {@code cursor} to the next index in row-major order (last dimension fastest).
     *
     * @return {@code true} if the cursor advanced; {@code false} once every index has wrapped back
     *     to zero (iteration is complete)
     */
    public final boolean iterate(int[] cursor) {
        Preconditions.checkArgument(cursor.length == this.shape.dims());
        for (int d = cursor.length - 1; d >= 0; d--) {
            Preconditions.checkArgument(cursor[d] >= 0 && cursor[d] < this.shape.dim(d));
            if (cursor[d] + 1 < this.shape.dim(d)) {
                cursor[d]++;
                return true;
            }
            // This dimension overflowed: reset it and carry into the next-outer dimension.
            cursor[d] = 0;
        }
        return false;
    }

    /** See {@link #stride}; 0 unless this is a 2-D tensor with more than one row. */
    public final int getStride() {
        return this.stride;
    }

    /** Element offset of the given indices, as computed by the shape. */
    public final int getOffset(int... indices) {
        return this.shape.getOffset(indices);
    }

    /**
     * Returns a new tensor with the dimension order reversed, copying element by element.
     *
     * @throws IllegalArgumentException if this tensor is sparse
     */
    public final AbstractTensor transpose() {
        Preconditions.checkArgument(!this.shape.isSparse(), "Cannot transpose a sparse tensor");
        int rank = dims();
        int[] reversedDims = new int[rank];
        for (int d = 0; d < rank; d++) {
            reversedDims[d] = this.shape.dim(rank - 1 - d);
        }
        AbstractTensor transposed = make(TensorShape.of(reversedDims));
        int[] src = new int[rank];
        int[] dst = new int[rank];
        do {
            float value = get(src);
            for (int d = 0; d < rank; d++) {
                dst[d] = src[rank - 1 - d];
            }
            transposed.set(value, dst);
        } while (iterate(src));
        return transposed;
    }

    /** Element data type. */
    public final DType dType() {
        return this.dType;
    }

    /** Loads a vector of {@code vectorSpecies} starting at the given indices. */
    public abstract V getVector(VectorSpecies<T> vectorSpecies, int... indices);

    /** Stores {@code vector} into this tensor starting at the given indices. */
    public abstract void intoTensor(V vector, int... indices);

    /**
     * Masked store; unsupported by default.
     *
     * @throws UnsupportedOperationException unless overridden
     */
    public void intoTensor(V vector, VectorMask<T> mask, int... indices) {
        throw new UnsupportedOperationException();
    }

    /** Backing memory of this tensor. */
    public abstract MemorySegment getMemorySegment();

    /** Maps an element offset to an offset within {@link #getMemorySegment()} (units defined by the subclass). */
    public abstract int getMemorySegmentOffset(int offset);

    /** Copies {@code length} elements from {@code src} at {@code srcOffset} into this tensor at {@code destOffset}. */
    public abstract void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length);

    /** Clears the contents of this tensor. */
    public abstract void clear();

    /** Releases this tensor back to the cache it came from, if any. */
    @Override // java.lang.AutoCloseable
    public void close() {
        // Read the volatile field once so the null check and the call see the same cache.
        TensorCache cache = this.originCache;
        if (cache != null) {
            cache.release(this);
        }
    }

    /** Records the cache this tensor was handed out by; used by {@link #close()}. */
    public void setOwnerCache(TensorCache tensorCache) {
        this.originCache = tensorCache;
    }

    /** Equivalent to {@code quantize(targetType, false)}. */
    public AbstractTensor quantize(DType targetType) {
        return quantize(targetType, false);
    }

    /**
     * Returns a new tensor of {@code targetType} constructed from this one.
     *
     * <p>Unless {@code force} is set, returns {@code this} when the first dimension is 1, the type
     * already matches, or the current type's size is smaller than the target's. Sparse tensors and
     * unsupported target types also return {@code this}.
     */
    public AbstractTensor quantize(DType targetType, boolean force) {
        if (!force && (shape().first() == 1 || this.dType == targetType || this.dType.size() < targetType.size())) {
            return this;
        }
        if (this.shape.isSparse()) {
            logger.info("Quantizing sparse tensor is not supported");
            return this;
        }
        return switch (targetType) {
            case F32 -> new FloatBufferTensor(this);
            case BF16 -> new BFloat16BufferTensor(this);
            case I8 -> new Q8ByteBufferTensor(this);
            case Q4 -> new Q4ByteBufferTensor(this);
            default -> this;
        };
    }

    /**
     * Writes the raw contents of {@link #getMemorySegment()} at the channel's current position.
     *
     * @return tensor metadata: dtype, dimensions, and the {start, end} byte positions written
     * @throws IllegalArgumentException if this tensor is sparse
     * @throws IOException on write failure
     */
    public TensorInfo save(FileChannel fileChannel) throws IOException {
        Preconditions.checkArgument(!this.shape.isSparse(), "Cannot save a sparse tensor");
        // NOTE(review): ByteBuffer byte order only affects multi-byte get/put; channel writes copy
        // the raw bytes, so the on-disk layout is the in-memory layout regardless of order().
        ByteBuffer bytes = getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
        long start = fileChannel.position();
        // A single write() is not guaranteed to drain the buffer; loop until it is empty.
        while (bytes.hasRemaining()) {
            fileChannel.write(bytes);
        }
        long end = fileChannel.position();
        long[] dims = new long[this.shape.dims()];
        for (int d = 0; d < dims.length; d++) {
            dims[d] = this.shape.dim(d);
        }
        return new TensorInfo(this.dType, dims, new long[]{start, end});
    }

    /**
     * Prints {@code label = <sum>} to stdout, where the sum is over {@code get(0, i)} for
     * {@code i} in {@code [0, size())}. NOTE(review): assumes a single-row 2-D layout; confirm.
     */
    public void debug(String label) {
        double sum = 0.0d;
        for (int i = 0; i < size(); i++) {
            sum += get(0, i);
        }
        System.out.println(String.format("%s = %.5f", label, Double.valueOf(sum)));
    }
}
