package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/tensor/FloatBufferTensor.class */
/**
 * A {@link DType#F32} tensor whose elements live in a {@link FloatBuffer}.
 *
 * <p>The storage kind follows {@link TensorOperationsProvider#requiresOffHeapTensor()}: when the
 * provider requires off-heap tensors the buffer is direct (allocated through
 * {@link UnsafeDirectByteBuffer#allocateAlignedByteBuffer}), otherwise it is a heap buffer backed
 * by a {@code float[]}. A {@link MemorySegment} view over the same storage is kept for vector
 * loads/stores and bulk copies.
 */
public final class FloatBufferTensor extends AbstractTensor<FloatVector, Float, float[]> {
    private static final Logger logger = LoggerFactory.getLogger(FloatBufferTensor.class);

    /** Alignment argument passed to {@link UnsafeDirectByteBuffer#allocateAlignedByteBuffer}. */
    private static final long OFF_HEAP_ALIGNMENT = 64L;

    /** Element storage; direct or heap-backed depending on the operations provider. */
    private final FloatBuffer b;
    /** Label printed by {@link #toString()} and propagated to slices made by {@code make}. */
    private final String name;
    /** Memory-segment view over {@link #b}. */
    private final MemorySegment segment;

    /**
     * Creates an F32 copy of another tensor, converting it element by element via
     * {@link AbstractTensor#get(int...)}.
     *
     * @param abstractTensor the source tensor; must not be {@link DType#I32}
     * @throws IllegalArgumentException if the source tensor is I32
     */
    public FloatBufferTensor(AbstractTensor abstractTensor) {
        this(abstractTensor.shape);
        Preconditions.checkArgument(abstractTensor.dType != DType.I32, "This should never happen, likely a bug");
        int[] cursor = new int[abstractTensor.shape.dims()];
        // iterate() is expected to advance the cursor in place and return false once every
        // coordinate has been visited.
        do {
            set(abstractTensor.get(cursor), cursor);
        } while (abstractTensor.iterate(cursor));
    }

    /**
     * Creates a zero-initialised tensor with the given dimensions.
     *
     * @param iArr the size of each dimension
     */
    public FloatBufferTensor(int... iArr) {
        this(TensorShape.of(iArr));
    }

    /**
     * Allocates a new tensor of the given shape named {@code "tmp"}, choosing direct or heap
     * storage according to the active {@link TensorOperationsProvider}.
     *
     * @param tensorShape the shape of the new tensor
     */
    public FloatBufferTensor(TensorShape tensorShape) {
        super(DType.F32, tensorShape, true);
        this.name = "tmp";
        if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
            this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
                            Ints.checkedCast(tensorShape.size() * dType().size()), OFF_HEAP_ALIGNMENT)
                    .asFloatBuffer();
        } else {
            this.b = FloatBuffer.allocate(Ints.checkedCast(tensorShape.size()));
        }
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    /**
     * Wraps (or copies) an existing buffer under the name {@code "none"}.
     *
     * @see #FloatBufferTensor(String, FloatBuffer, TensorShape, boolean)
     */
    public FloatBufferTensor(FloatBuffer floatBuffer, TensorShape tensorShape, boolean z) {
        this("none", floatBuffer, tensorShape, z);
    }

    /**
     * Creates a tensor over {@code floatBuffer}.
     *
     * <p>If the buffer's kind (direct vs. heap) already matches what the operations provider
     * requires, it is wrapped without copying, so the tensor shares storage with it. Otherwise its
     * remaining elements are copied into freshly allocated storage of the required kind. The
     * caller's buffer position is left untouched in both cases.
     *
     * @param str name used by {@link #toString()}
     * @param floatBuffer source elements
     * @param tensorShape shape of the tensor
     * @param z flag forwarded unchanged to the {@link AbstractTensor} constructor
     */
    public FloatBufferTensor(String str, FloatBuffer floatBuffer, TensorShape tensorShape, boolean z) {
        super(DType.F32, tensorShape, z);
        this.name = str;
        if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
            if (floatBuffer.isDirect()) {
                this.b = floatBuffer;
            } else {
                this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
                                Ints.checkedCast(size() * dType().size()), OFF_HEAP_ALIGNMENT)
                        .asFloatBuffer();
                // put(FloatBuffer) advances the source's position; copy from a duplicate so the
                // caller's buffer is not consumed as a side effect.
                this.b.duplicate().put(floatBuffer.duplicate());
            }
        } else if (floatBuffer.isDirect()) {
            this.b = FloatBuffer.allocate(Ints.checkedCast(size()));
            this.b.duplicate().put(floatBuffer.duplicate());
        } else {
            this.b = floatBuffer;
        }
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    /** Allocates a new, empty tensor of the same type with the given shape. */
    @Override
    protected AbstractTensor make(TensorShape tensorShape) {
        return new FloatBufferTensor(tensorShape);
    }

    /**
     * Creates a tensor over {@code i2} elements of this tensor's buffer starting at element
     * {@code i}. The slice is passed through the wrapping constructor, so it shares storage with
     * this tensor whenever no copy is needed.
     */
    @Override
    public AbstractTensor make(int i, int i2, TensorShape tensorShape, boolean z) {
        return new FloatBufferTensor(this.name, this.b.slice(i, i2), tensorShape, z);
    }

    /**
     * Reads the element at the given coordinates.
     *
     * @throws IllegalArgumentException unless exactly one index per dimension is given
     */
    @Override
    public float get(int... iArr) {
        Preconditions.checkArgument(iArr.length <= this.shape.dims(), "Too many dimensions specified");
        Preconditions.checkArgument(iArr.length == this.shape.dims(), "Must specify all dimensions");
        int offset = getOffset(iArr);
        // Fast path: index the backing array directly when one is available.
        return this.b.hasArray() ? this.b.array()[this.b.arrayOffset() + offset] : this.b.get(offset);
    }

    /**
     * Writes {@code f} at the given coordinates.
     *
     * @throws IllegalArgumentException unless exactly one index per dimension is given, or if the
     *     buffer is read-only
     */
    @Override
    public void set(float f, int... iArr) {
        Preconditions.checkArgument(iArr.length <= this.shape.dims(), "Too many dimensions specified for tensor");
        Preconditions.checkArgument(iArr.length == this.shape.dims(), "Must specify all dimensions");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't modify a read only buffer");
        this.b.put(getOffset(iArr), f);
    }

    /**
     * Returns the backing array (shared, not copied).
     *
     * @throws IllegalArgumentException if the buffer is not array-backed
     */
    @Override
    public float[] getArray() {
        Preconditions.checkArgument(this.b.hasArray());
        return this.b.array();
    }

    /** Translates an element offset into an index into {@link #getArray()}. */
    @Override
    public int getArrayOffset(int i) {
        return (this.b.hasArray() ? this.b.arrayOffset() : 0) + i;
    }

    /** Returns the memory-segment view over this tensor's storage. */
    @Override
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    /**
     * Copies {@code i3} elements from {@code abstractTensor} (starting at element {@code i}) into
     * this tensor (starting at element {@code i2}) as a raw byte copy.
     *
     * @throws IllegalArgumentException if the source tensor's dtype differs from this tensor's,
     *     since a raw byte copy between different element types would produce garbage
     */
    @Override
    public void copyFrom(AbstractTensor abstractTensor, int i, int i2, int i3) {
        Preconditions.checkArgument(
                abstractTensor.dType == this.dType,
                "Cannot copy from a %s tensor into a %s tensor",
                abstractTensor.dType,
                this.dType);
        // Widen before multiplying so large copies do not overflow int.
        long byteCount = (long) i3 * this.dType.size();
        MemorySegment source = abstractTensor.getMemorySegment()
                .asSlice(abstractTensor.getMemorySegmentOffset(i), byteCount);
        this.segment.asSlice(getMemorySegmentOffset(i2), byteCount).copyFrom(source);
    }

    /** Converts an element offset into a byte offset within {@link #getMemorySegment()}. */
    @Override
    public int getMemorySegmentOffset(int i) {
        return i * Float.BYTES;
    }

    /**
     * Loads a vector of {@code vectorSpecies.length()} floats starting at the given coordinates,
     * from the backing array on heap or from the memory segment (little-endian) off heap.
     */
    @Override
    public FloatVector getVector(VectorSpecies<Float> vectorSpecies, int... iArr) {
        int offset = getOffset(iArr);
        if (this.requiresOffHeapTensor) {
            return FloatVector.fromMemorySegment(
                    vectorSpecies, this.segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
        }
        return FloatVector.fromArray(vectorSpecies, getArray(), getArrayOffset(offset));
    }

    /** Stores {@code floatVector} into this tensor starting at the given coordinates. */
    @Override
    public void intoTensor(FloatVector floatVector, int... iArr) {
        int offset = getOffset(iArr);
        if (this.requiresOffHeapTensor) {
            floatVector.intoMemorySegment(this.segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
        } else {
            floatVector.intoArray(getArray(), getArrayOffset(offset));
        }
    }

    /** Sets every element of this tensor to zero. */
    @Override
    public void clear() {
        if (this.b.hasArray()) {
            Arrays.fill(this.b.array(), getArrayOffset(0), getArrayOffset(Ints.checkedCast(size())), 0.0f);
        } else {
            this.segment.fill((byte) 0);
        }
    }

    /** Returns the name, shape, and up to the first ten elements. */
    @Override
    public String toString() {
        float[] fArr = new float[Math.min(10, this.b.remaining())];
        this.b.duplicate().get(fArr);
        return "FloatBufferTensor{name='" + this.name + "' shape=" + String.valueOf(this.shape) + ", b=" + Arrays.toString(fArr) + "...}";
    }
}
