package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.operations.cnative.NativeSimd;
import com.github.tjake.jlama.tensor.operations.util.JarSupport;
import com.github.tjake.jlama.tensor.operations.util.MemorySegmentSupport;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.RuntimeSupport;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/tensor/operations/NativeTensorOperations.class */
/**
 * {@link TensorOperations} implementation that dispatches the two dot-product entry points to the
 * {@code gemm_*} functions exposed by {@link NativeSimd}, and forwards every other operation to a
 * delegate that is selected once, when this class is initialised.
 *
 * <p>Class initialisation loads the native library (via {@link JarSupport#maybeLoadLibrary()} or,
 * failing that, {@code System.loadLibrary("jlama")}), so merely touching this class has that side
 * effect.
 */
public class NativeTensorOperations implements TensorOperations {
    private static final Logger logger = LoggerFactory.getLogger(NativeTensorOperations.class);

    /** Flag value obtained from {@link NativeSimd#HAS_F16C()} during class initialisation. */
    public static final int HAS_F16C;

    /** Flag value obtained from {@link NativeSimd#HAS_AVX2()} during class initialisation. */
    public static final int HAS_AVX2;

    /** Handles every operation for which this class makes no native call of its own. */
    private static final TensorOperations delegate;

    /** Passed as the first argument of every {@link NativeSimd} call made by this instance. */
    final int flags;

    /**
     * Builds the flag set from the runtime environment: {@link #HAS_F16C} is included when
     * {@link RuntimeSupport#isLinux()} is true, and {@link #HAS_AVX2} is included when
     * {@link MachineSpec#VECTOR_TYPE} is {@code AVX_512}.
     */
    public NativeTensorOperations() {
        int detected = 0;
        if (RuntimeSupport.isLinux()) {
            detected |= HAS_F16C;
        }
        if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) {
            detected |= HAS_AVX2;
        }
        this.flags = detected;
        checkLib();
    }

    /**
     * Creates an instance with an explicit flag set, bypassing environment detection.
     *
     * @param i the value to store in {@link #flags}
     */
    NativeTensorOperations(int i) {
        this.flags = i;
    }

    /** Returns the fixed display name of this implementation. */
    public String name() {
        return "Native SIMD Operations";
    }

    /** Intentionally empty in this build; retained because the public constructor calls it. */
    private void checkLib() {
    }

    /** Returns the constant split size {@code 128}. */
    public int parallelSplitSize() {
        return 128;
    }

    /**
     * Selects the {@link NativeSimd} kernel that matches the element types of
     * {@code abstractTensor2} and {@code abstractTensor3} and invokes it.
     *
     * <p>Supported type pairs ({@code abstractTensor2} x {@code abstractTensor3}): BF16 x BF16,
     * F32 x BF16, F32 x F32, F32 x Q4 (rejected when {@link MachineSpec#VECTOR_TYPE} is
     * {@code ARM_128}) and I8 x Q4. The Q4 and I8 paths cast the operands to
     * {@link Q4ByteBufferTensor} / {@link Q8ByteBufferTensor}.
     *
     * @param abstractTensor  tensor whose segment and stride are passed after both operands
     *                        (presumably the destination - confirm against the native signature);
     *                        for the BF16-operand kernels its segment is supplied in the BF16 slot
     *                        or the F32 slot according to its own dtype, the other slot receiving
     *                        {@link MemorySegment#NULL}
     * @param abstractTensor2 first operand; its dtype drives the outer dispatch
     * @param abstractTensor3 second operand; its dtype drives the inner dispatch
     * @param i               second coordinate used to locate the start offset in {@code abstractTensor2}
     * @param i2              second coordinate used to locate the start offset in {@code abstractTensor3}
     * @param i3              forwarded unchanged to the kernel
     * @param i4              subtracted when computing the offset passed for {@code abstractTensor}
     * @param i5              row offset; reduced by {@code abstractTensor3}'s sparse row offset before
     *                        being forwarded
     * @param i6              forwarded unchanged to the kernel
     * @throws UnsupportedOperationException if the dtype pair has no kernel here
     */
    public void batchDotProduct(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        int firstDim = abstractTensor2.shape().dim(0);
        int aOffset = abstractTensor2.getOffset(new int[]{0, i});
        // The second operand is addressed starting from its sparse row offset.
        int bOffset = abstractTensor3.getOffset(new int[]{abstractTensor3.shape().sparseRowOffset(), i2});
        int rOffset = (abstractTensor.shape().sparseColumnOffset() - abstractTensor3.shape().sparseRowOffset()) - i4;
        int adjustedRowOffset = i5 - abstractTensor3.shape().sparseRowOffset();

        switch (abstractTensor2.dType()) {
            case BF16:
                switch (abstractTensor3.dType()) {
                    case BF16:
                        NativeSimd.gemm_bf16(
                                this.flags,
                                abstractTensor2.getMemorySegment(),
                                aOffset,
                                abstractTensor3.getMemorySegment(),
                                bOffset,
                                abstractTensor.dType() == DType.BF16 ? abstractTensor.getMemorySegment() : MemorySegment.NULL,
                                abstractTensor.dType() == DType.F32 ? abstractTensor.getMemorySegment() : MemorySegment.NULL,
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i6,
                                i3,
                                abstractTensor2.getStride(),
                                abstractTensor3.getStride(),
                                abstractTensor.getStride());
                        return;
                    default:
                        throw new UnsupportedOperationException(abstractTensor2.dType().name() + " " + abstractTensor3.dType().name());
                }
            case F32:
                switch (abstractTensor3.dType()) {
                    case BF16:
                        NativeSimd.gemm_f32_bf16(
                                this.flags,
                                abstractTensor2.getMemorySegment(),
                                aOffset,
                                abstractTensor3.getMemorySegment(),
                                bOffset,
                                abstractTensor.dType() == DType.BF16 ? abstractTensor.getMemorySegment() : MemorySegment.NULL,
                                abstractTensor.dType() == DType.F32 ? abstractTensor.getMemorySegment() : MemorySegment.NULL,
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i6,
                                i3,
                                abstractTensor2.getStride(),
                                abstractTensor3.getStride(),
                                abstractTensor.getStride());
                        return;
                    case F32:
                        NativeSimd.gemm_f32(
                                this.flags,
                                abstractTensor2.getMemorySegment(),
                                aOffset,
                                abstractTensor3.getMemorySegment(),
                                bOffset,
                                abstractTensor.getMemorySegment(),
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i6,
                                i3,
                                abstractTensor2.getStride(),
                                abstractTensor3.getStride(),
                                abstractTensor.getStride());
                        return;
                    case Q4:
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128:
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            default: {
                                Q4ByteBufferTensor q4 = (Q4ByteBufferTensor) abstractTensor3;
                                NativeSimd.gemm_f32_q4(
                                        this.flags,
                                        abstractTensor2.getMemorySegment(),
                                        aOffset,
                                        q4.getBlockF().getMemorySegment(),
                                        q4.getMemorySegment(),
                                        q4.getMemorySegmentOffset(bOffset),
                                        abstractTensor.getMemorySegment(),
                                        rOffset,
                                        firstDim,
                                        adjustedRowOffset,
                                        i6,
                                        i3,
                                        abstractTensor2.getStride(),
                                        q4.getMemorySegmentOffset(q4.getStride()),
                                        q4.getBlockF().getStride(),
                                        abstractTensor.getStride());
                                return;
                            }
                        }
                    default:
                        throw new UnsupportedOperationException(abstractTensor2.dType().name() + " " + abstractTensor3.dType().name());
                }
            case I8:
                switch (abstractTensor3.dType()) {
                    case Q4: {
                        Q8ByteBufferTensor q8 = (Q8ByteBufferTensor) abstractTensor2;
                        Q4ByteBufferTensor q4 = (Q4ByteBufferTensor) abstractTensor3;
                        NativeSimd.gemm_q8_q4(
                                this.flags,
                                q8.getBlockF().getMemorySegment(),
                                q8.getMemorySegment(),
                                aOffset,
                                q4.getBlockF().getMemorySegment(),
                                q4.getMemorySegment(),
                                q4.getMemorySegmentOffset(bOffset),
                                abstractTensor.getMemorySegment(),
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i6,
                                i3,
                                q8.getStride(),
                                q8.getBlockF().getStride(),
                                q4.getMemorySegmentOffset(q4.getStride()),
                                q4.getBlockF().getStride(),
                                abstractTensor.getStride());
                        return;
                    }
                    default:
                        throw new UnsupportedOperationException(abstractTensor2.dType().name() + " " + abstractTensor3.dType().name());
                }
            default:
                // Covers Q4 as well as any dtype without a kernel for the first operand.
                throw new UnsupportedOperationException(abstractTensor2.dType().name());
        }
    }

    /**
     * Batched counterpart of {@link #batchDotProduct}: one first operand is combined with an array
     * of second operands, using the {@code *_batch} kernels of {@link NativeSimd}.
     *
     * <p>The per-element segments are gathered through {@link MemorySegmentSupport#setupBatch},
     * whose returned array supplies, in order, the segment used for {@code abstractTensorArr}, the
     * segment used for {@code abstractTensorArr2}, and the segment built from the Q4 block-factor
     * tensors (elements that are not {@link Q4ByteBufferTensor} contribute
     * {@link MemorySegment#NULL}). Dtype, stride and sparse offsets are read from element 0 of each
     * array only, so both arrays must be non-empty.
     *
     * @param abstractTensorArr  tensors paired index-for-index with {@code abstractTensorArr2}
     *                           (presumably the destinations - confirm against the native signature)
     * @param abstractTensor     first operand; its dtype drives the outer dispatch
     * @param abstractTensorArr2 second operands; the dtype of element 0 drives the inner dispatch
     * @param i                  second coordinate used to locate the start offset in both operands
     * @param i2                 forwarded unchanged to the kernel
     * @param i3                 row offset; see the review note on the F32 branches below
     * @param i4                 forwarded unchanged to the kernel
     * @throws UnsupportedOperationException if the dtype pair has no kernel here
     */
    public void dotProductBatchChunk(AbstractTensor[] abstractTensorArr, AbstractTensor abstractTensor, AbstractTensor[] abstractTensorArr2, int i, int i2, int i3, int i4) {
        MemorySegment[] batch = MemorySegmentSupport.setupBatch(
                idx -> abstractTensorArr[idx].getMemorySegment(),
                idx -> abstractTensorArr2[idx].getMemorySegment(),
                idx -> abstractTensorArr2[idx] instanceof Q4ByteBufferTensor
                        ? ((Q4ByteBufferTensor) abstractTensorArr2[idx]).getBlockF().getMemorySegment()
                        : MemorySegment.NULL,
                abstractTensorArr.length);
        MemorySegment resultBatch = batch[0];
        MemorySegment bBatch = batch[1];
        MemorySegment blockFBatch = batch[2];

        int firstDim = abstractTensor.shape().dim(0);
        int aOffset = abstractTensor.getOffset(new int[]{0, i});
        AbstractTensor firstB = abstractTensorArr2[0];
        int bOffset = firstB.getOffset(new int[]{firstB.shape().sparseRowOffset(), i});
        int adjustedRowOffset = i3 - firstB.shape().sparseRowOffset();
        AbstractTensor firstResult = abstractTensorArr[0];
        int rOffset = firstResult.shape().sparseColumnOffset() - firstB.shape().sparseRowOffset();

        switch (abstractTensor.dType()) {
            case BF16:
                switch (firstB.dType()) {
                    case BF16:
                        NativeSimd.gemm_bf16_batch(
                                this.flags,
                                abstractTensorArr.length,
                                abstractTensor.getMemorySegment(),
                                aOffset,
                                bBatch,
                                bOffset,
                                firstResult.dType() == DType.BF16 ? resultBatch : MemorySegment.NULL,
                                firstResult.dType() == DType.F32 ? resultBatch : MemorySegment.NULL,
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i4,
                                i2,
                                abstractTensor.getStride(),
                                firstB.getStride(),
                                firstResult.getStride());
                        return;
                    default:
                        throw new UnsupportedOperationException(abstractTensor.dType().name() + " " + abstractTensorArr2[0].dType().name());
                }
            case F32:
                switch (firstB.dType()) {
                    case BF16:
                        NativeSimd.gemm_f32_bf16_batch(
                                this.flags,
                                abstractTensorArr.length,
                                abstractTensor.getMemorySegment(),
                                aOffset,
                                bBatch,
                                bOffset,
                                firstResult.dType() == DType.BF16 ? resultBatch : MemorySegment.NULL,
                                firstResult.dType() == DType.F32 ? resultBatch : MemorySegment.NULL,
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i4,
                                i2,
                                abstractTensor.getStride(),
                                firstB.getStride(),
                                firstResult.getStride());
                        return;
                    case F32:
                        // NOTE(review): this branch forwards the raw i3, whereas the BF16 and I8
                        // branches forward i3 minus the sparse row offset - confirm this is intended.
                        NativeSimd.gemm_f32_batch(
                                this.flags,
                                abstractTensorArr.length,
                                abstractTensor.getMemorySegment(),
                                aOffset,
                                bBatch,
                                bOffset,
                                resultBatch,
                                rOffset,
                                firstDim,
                                i3,
                                i4,
                                i2,
                                abstractTensor.getStride(),
                                firstB.getStride(),
                                firstResult.getStride());
                        return;
                    case Q4:
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128:
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            default: {
                                Q4ByteBufferTensor q4 = (Q4ByteBufferTensor) abstractTensorArr2[0];
                                // NOTE(review): raw i3 is forwarded here as well (see the F32 branch).
                                NativeSimd.gemm_f32_q4_batch(
                                        this.flags,
                                        abstractTensorArr.length,
                                        abstractTensor.getMemorySegment(),
                                        aOffset,
                                        blockFBatch,
                                        bBatch,
                                        q4.getMemorySegmentOffset(bOffset),
                                        resultBatch,
                                        rOffset,
                                        firstDim,
                                        i3,
                                        i4,
                                        i2,
                                        abstractTensor.getStride(),
                                        firstB.getMemorySegmentOffset(firstB.getStride()),
                                        q4.getBlockF().getStride(),
                                        firstResult.getStride());
                                return;
                            }
                        }
                    default:
                        throw new UnsupportedOperationException(abstractTensor.dType().name() + " " + abstractTensorArr2[0].dType().name());
                }
            case I8:
                switch (firstB.dType()) {
                    case Q4: {
                        Q8ByteBufferTensor q8 = (Q8ByteBufferTensor) abstractTensor;
                        Q4ByteBufferTensor q4 = (Q4ByteBufferTensor) abstractTensorArr2[0];
                        NativeSimd.gemm_q8_q4_batch(
                                this.flags,
                                abstractTensorArr.length,
                                q8.getBlockF().getMemorySegment(),
                                abstractTensor.getMemorySegment(),
                                aOffset,
                                blockFBatch,
                                bBatch,
                                q4.getMemorySegmentOffset(bOffset),
                                resultBatch,
                                rOffset,
                                firstDim,
                                adjustedRowOffset,
                                i4,
                                i2,
                                abstractTensor.getStride(),
                                q8.getBlockF().getStride(),
                                q4.getMemorySegmentOffset(q4.getStride()),
                                q4.getBlockF().getStride(),
                                firstResult.getStride());
                        return;
                    }
                    default:
                        throw new UnsupportedOperationException(abstractTensor.dType().name() + " " + abstractTensorArr2[0].dType().name());
                }
            default:
                // Covers Q4 as well as any dtype without a kernel for the first operand.
                throw new UnsupportedOperationException(abstractTensor.dType().name());
        }
    }

    /** Forwards to the delegate's {@code accumulate} with the arguments unchanged. */
    public void accumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2) {
        delegate.accumulate(abstractTensor, abstractTensor2, i, i2);
    }

    /** Forwards to the delegate's {@code maccumulate} with the arguments unchanged. */
    public void maccumulate(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2) {
        delegate.maccumulate(abstractTensor, abstractTensor2, i, i2);
    }

    /** Forwards to the delegate's scalar {@code saxpy} with the arguments unchanged. */
    public void saxpy(float f, AbstractTensor abstractTensor, AbstractTensor abstractTensor2, int i, int i2, int i3) {
        delegate.saxpy(f, abstractTensor, abstractTensor2, i, i2, i3);
    }

    /** Forwards to the delegate's tensor {@code saxpy} with the arguments unchanged. */
    public void saxpy(AbstractTensor abstractTensor, AbstractTensor abstractTensor2, AbstractTensor abstractTensor3, int i, int i2, int i3, int i4, int i5, int i6) {
        delegate.saxpy(abstractTensor, abstractTensor2, abstractTensor3, i, i2, i3, i4, i5, i6);
    }

    /** Forwards to the delegate's {@code scale} with the arguments unchanged. */
    public void scale(float f, AbstractTensor abstractTensor, int i, int i2) {
        delegate.scale(f, abstractTensor, i, i2);
    }

    /**
     * Forwards to the delegate's {@code quantize} with the arguments unchanged.
     *
     * @return whatever the delegate returns
     */
    public AbstractTensor quantize(AbstractTensor abstractTensor, DType dType, int i, int i2) {
        return delegate.quantize(abstractTensor, dType, i, i2);
    }

    static {
        // Prefer the library bundled with the jar; otherwise fall back to the system library path.
        if (!JarSupport.maybeLoadLibrary()) {
            System.loadLibrary("jlama");
        }
        HAS_F16C = NativeSimd.HAS_F16C();
        HAS_AVX2 = NativeSimd.HAS_AVX2();

        // Typed as the interface so that either implementation can be assigned.
        TensorOperations chosen;
        try {
            chosen = new PanamaTensorOperations(MachineSpec.VECTOR_TYPE);
        } catch (Throwable th) {
            // Best effort: any failure to construct the Panama implementation selects the naive one.
            logger.warn("Unable to create PanamaTensorOperations, delegating to NaiveTensorOperations instead", th);
            chosen = new NaiveTensorOperations();
        }
        delegate = chosen;
    }
}
