package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;

/* loaded from: input_file:com/github/tjake/jlama/tensor/operations/TensorOperations.class */
/**
 * Contract for a backend that performs numeric operations on {@link AbstractTensor}s.
 *
 * <p>Implementations supply the abstract primitives; the default methods here are
 * convenience overloads that delegate to those primitives.
 *
 * <p>NOTE(review): the parameter names in this file are tool-generated ({@code i},
 * {@code i2}, ...). Wherever the meaning of an argument is not established by code in
 * this interface, the documentation below says so instead of guessing.
 */
public interface TensorOperations {
    /**
     * Per-thread tensor created from {@link TensorShape#one}. The five-argument
     * {@code dotProduct} overload uses it as the destination of a
     * {@code batchDotProduct} call and then reads element {@code (0, 0)} from it.
     */
    ThreadLocal<FloatBufferTensor> scratch = ThreadLocal.withInitial(() -> new FloatBufferTensor(TensorShape.one));

    /** Returns the name of this implementation. */
    String name();

    /**
     * Reports whether this implementation requires off-heap tensors.
     * NOTE(review): inferred from the method name only; confirm with implementations.
     */
    boolean requiresOffHeapTensor();

    /**
     * Returns the parallel split size.
     *
     * @return {@code 1} unless an implementation overrides it
     */
    default int parallelSplitSize() {
        return 1;
    }

    /**
     * Convenience overload of {@link #dotProduct(AbstractTensor, AbstractTensor, int, int, int)}
     * that passes {@code 0} for both middle integer arguments.
     *
     * @param abstractTensor  first operand
     * @param abstractTensor2 second operand
     * @param i               forwarded unchanged as the last argument
     *                        (presumably the element count; TODO confirm)
     * @return the value produced by the five-argument overload
     */
    default float dotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i) {
        return dotProduct(abstractTensor, abstractTensor2, 0, 0, i);
    }

    /**
     * Runs the six-argument {@code batchDotProduct} with this thread's {@link #scratch}
     * tensor as the destination and returns the scratch element at {@code (0, 0)}.
     *
     * @param abstractTensor  first operand
     * @param abstractTensor2 second operand
     * @param i               forwarded as the fourth {@code batchDotProduct} argument
     * @param i2              forwarded as the fifth {@code batchDotProduct} argument
     * @param i3              forwarded as the sixth {@code batchDotProduct} argument
     * @return element {@code (0, 0)} of the scratch tensor after the call
     */
    default float dotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2, int i3) {
        FloatBufferTensor result = scratch.get();
        batchDotProduct(result, abstractTensor, abstractTensor2, i, i2, i3);
        return result.get(0, 0);
    }

    /**
     * Convenience overload of the eight-argument {@code batchDotProduct}: passes
     * {@code 0} as the seventh argument and {@code abstractTensor3.shape().first()}
     * as the eighth.
     *
     * @param abstractTensor  forwarded as the first argument (used as the result
     *                        destination by {@code dotProduct})
     * @param abstractTensor2 forwarded as the second argument
     * @param abstractTensor3 forwarded as the third argument; its first dimension is
     *                        also used as the eighth argument
     * @param i               forwarded unchanged
     * @param i2              forwarded unchanged
     * @param i3              forwarded unchanged
     */
    default void batchDotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3) {
        batchDotProduct(abstractTensor, abstractTensor2, abstractTensor3, i, i2, i3, 0, abstractTensor3.shape().first());
    }

    /**
     * Primitive that every default dot-product method in this interface ends up calling.
     * NOTE(review): the meaning of the integer arguments (presumably column offsets, a
     * column limit, a row offset and a row count) is not visible here; confirm against
     * an implementation.
     */
    void batchDotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5);

    /**
     * Calls the eight-argument {@code batchDotProduct}, using {@code i} for both its
     * fourth and fifth arguments.
     *
     * @param abstractTensor  forwarded as the first argument
     * @param abstractTensor2 forwarded as the second argument
     * @param abstractTensor3 forwarded as the third argument
     * @param i               forwarded as both the fourth and fifth arguments
     * @param i2              forwarded as the sixth argument
     * @param i3              forwarded as the seventh argument
     * @param i4              forwarded as the eighth argument
     */
    default void dotProductChunk(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4) {
        batchDotProduct(abstractTensor, abstractTensor2, abstractTensor3, i, i, i2, i3, i4);
    }

    /**
     * Invokes {@link #dotProductChunk} once per array index, pairing
     * {@code abstractTensorArr[k]} with {@code abstractTensorArr2[k]} and reusing
     * {@code abstractTensor} for every call.
     *
     * @param abstractTensorArr  first argument of each {@code dotProductChunk} call
     * @param abstractTensor     second argument shared by every call
     * @param abstractTensorArr2 third argument of each call; element 0 must have two
     *                           dimensions and the array must be as long as
     *                           {@code abstractTensorArr}
     * @param i                  forwarded unchanged
     * @param i2                 forwarded unchanged
     * @param i3                 forwarded unchanged
     * @param i4                 forwarded unchanged
     * @throws IllegalArgumentException if the dimension or length check fails
     */
    default void dotProductBatchChunk(AbstractTensor[] abstractTensorArr, AbstractTensor abstractTensor, AbstractTensor[] abstractTensorArr2, int i, int i2, int i3, int i4) {
        // Only the first element of abstractTensorArr2 is inspected for its rank.
        boolean firstIsTwoDimensional = abstractTensorArr2[0].dims() == 2;
        Preconditions.checkArgument(firstIsTwoDimensional && abstractTensorArr.length == abstractTensorArr2.length);

        final int count = abstractTensorArr.length;
        for (int k = 0; k < count; k++) {
            dotProductChunk(abstractTensorArr[k], abstractTensor, abstractTensorArr2[k], i, i2, i3, i4);
        }
    }

    /**
     * NOTE(review): presumably adds one tensor into the other over a range given by
     * {@code i} and {@code i2}; direction and argument meaning are not visible here.
     */
    void accumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2);

    /**
     * NOTE(review): presumably the multiplicative counterpart of {@link #accumulate};
     * confirm against an implementation.
     */
    void maccumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2);

    /**
     * Scalar primitive used by the batched {@code saxpy} overload.
     * NOTE(review): presumably the BLAS-style update {@code y = f * x + y} with two
     * offsets and a limit; confirm against an implementation.
     */
    void saxpy(float f, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2, int i3);

    /**
     * Batched form: for each {@code k} in {@code [0, i4)} calls the scalar
     * {@code saxpy} with {@code abstractTensor.get(k)} as the scalar and
     * {@code abstractTensor2.slice(k)} as the first tensor; {@code abstractTensor3}
     * and the integer arguments are passed through unchanged on every call.
     *
     * @param abstractTensor  source of the per-iteration scalar
     * @param abstractTensor2 tensor sliced once per iteration
     * @param abstractTensor3 forwarded unchanged to every scalar call
     * @param i               forwarded unchanged
     * @param i2              forwarded unchanged
     * @param i3              forwarded unchanged
     * @param i4              number of iterations
     * @throws IllegalArgumentException if the first dimensions of
     *         {@code abstractTensor} and {@code abstractTensor2} differ
     */
    default void saxpy(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4) {
        Preconditions.checkArgument(abstractTensor.shape().first() == abstractTensor2.shape().first());
        for (int k = 0; k < i4; k++) {
            float scalar = abstractTensor.get(k);
            AbstractTensor slice = abstractTensor2.slice(k);
            saxpy(scalar, slice, abstractTensor3, i, i2, i3);
        }
    }

    /**
     * NOTE(review): presumably {@code y = x + f * y}, going by the name; argument
     * meaning is not visible here.
     */
    void sxpby(float f, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2, int i3);

    /**
     * NOTE(review): presumably multiplies a range of the tensor by {@code f};
     * confirm against an implementation.
     */
    void scale(float f, AbstractTensor abstractTensor, int i, int i2);

    /**
     * Default implementation: obtains a tensor from {@code TensorCache.instance} with
     * the same dtype and shape as the input, fills it with
     * {@code copyFrom(abstractTensor, i, i, i2)} and returns it.
     *
     * <p>The {@code dType} argument is not used here, so this default does not change
     * the data type of the result.
     *
     * @param abstractTensor tensor to copy from
     * @param dType          requested type; ignored by this default
     * @param i              passed as both the second and third {@code copyFrom} argument
     * @param i2             passed as the last {@code copyFrom} argument
     * @return the tensor obtained from the cache, after the copy
     */
    default AbstractTensor quantize(AbstractTensor abstractTensor, DType dType, int i, int i2) {
        AbstractTensor copy = TensorCache.instance.get(abstractTensor.dType(), abstractTensor.shape());
        copy.copyFrom(abstractTensor, i, i, i2);
        return copy;
    }

    /**
     * Adds up, in {@code float} arithmetic, every value read via
     * {@code abstractTensor.get(cursor)} for as long as
     * {@code abstractTensor.iterate(cursor)} keeps returning {@code true}.
     *
     * @param abstractTensor tensor to sum
     * @return the accumulated total; {@code 0.0f} if {@code iterate} never returns true
     */
    default float sum(AbstractTensor abstractTensor) {
        float total = 0.0f;
        // The cursor has one slot per dimension and is passed to iterate() before each read.
        for (int[] cursor = new int[abstractTensor.dims()]; abstractTensor.iterate(cursor); ) {
            total += abstractTensor.get(cursor);
        }
        return total;
    }
}
