/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.operations.NaiveTensorOperations;
import com.github.tjake.jlama.tensor.operations.PanamaTensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.tensor.operations.cnative.NativeSimd;
import com.github.tjake.jlama.tensor.operations.util.JarSupport;
import com.github.tjake.jlama.tensor.operations.util.MemorySegmentSupport;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.RuntimeSupport;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NativeTensorOperations
implements TensorOperations {
    private static final Logger logger;
    public static final int HAS_F16C;
    public static final int HAS_AVX2;
    private static final TensorOperations delegate;
    final int flags;

    public NativeTensorOperations() {
        int f = 0;
        if (RuntimeSupport.isLinux()) {
            f |= HAS_F16C;
        }
        if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) {
            f |= HAS_AVX2;
        }
        this.flags = f;
        this.checkLib();
    }

    NativeTensorOperations(int flags) {
        this.flags = flags;
    }

    public String name() {
        return "Native SIMD Operations";
    }

    private void checkLib() {
    }

    public int parallelSplitSize() {
        return 128;
    }

    public void batchDotProduct(AbstractTensor result, AbstractTensor at, AbstractTensor bt, int aColumnOffset, int bColumnOffset, int columnLength, int bRowOffset, int rowChunkSize) {
        int M = at.shape().dim(0);
        int N = rowChunkSize;
        int K = columnLength;
        block0 : switch (at.dType()) {
            case BF16: {
                switch (bt.dType()) {
                    case BF16: {
                        NativeSimd.gemm_bf16((int)this.flags, (MemorySegment)at.getMemorySegment(), (int)at.getOffset(new int[]{0, aColumnOffset}), (MemorySegment)bt.getMemorySegment(), (int)bt.getOffset(new int[]{0, bColumnOffset}), (MemorySegment)(result.dType() == DType.BF16 ? result.getMemorySegment() : MemorySegment.NULL), (MemorySegment)(result.dType() == DType.F32 ? result.getMemorySegment() : MemorySegment.NULL), (int)result.shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)at.getStride(), (int)bt.getStride(), (int)result.getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
            }
            case F32: {
                switch (bt.dType()) {
                    case F32: {
                        NativeSimd.gemm_f32((int)this.flags, (MemorySegment)at.getMemorySegment(), (int)at.getOffset(new int[]{0, aColumnOffset}), (MemorySegment)bt.getMemorySegment(), (int)bt.getOffset(new int[]{0, bColumnOffset}), (MemorySegment)result.getMemorySegment(), (int)result.shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)at.getStride(), (int)bt.getStride(), (int)result.getStride());
                        break block0;
                    }
                    case BF16: {
                        NativeSimd.gemm_f32_bf16((int)this.flags, (MemorySegment)at.getMemorySegment(), (int)at.getOffset(new int[]{0, aColumnOffset}), (MemorySegment)bt.getMemorySegment(), (int)bt.getOffset(new int[]{0, bColumnOffset}), (MemorySegment)(result.dType() == DType.BF16 ? result.getMemorySegment() : MemorySegment.NULL), (MemorySegment)(result.dType() == DType.F32 ? result.getMemorySegment() : MemorySegment.NULL), (int)result.shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)at.getStride(), (int)bt.getStride(), (int)result.getStride());
                        break block0;
                    }
                    case Q4: {
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128: {
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            }
                        }
                        Q4ByteBufferTensor b = (Q4ByteBufferTensor)bt;
                        NativeSimd.gemm_f32_q4((int)this.flags, (MemorySegment)at.getMemorySegment(), (int)at.getOffset(new int[]{0, aColumnOffset}), (MemorySegment)b.getBlockF().getMemorySegment(), (MemorySegment)b.getMemorySegment(), (int)b.getMemorySegmentOffset(b.getOffset(new int[]{0, bColumnOffset})), (MemorySegment)result.getMemorySegment(), (int)result.shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)at.getStride(), (int)b.getMemorySegmentOffset(b.getStride()), (int)b.getBlockF().getStride(), (int)result.getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
            }
            case I8: {
                switch (bt.dType()) {
                    case Q4: {
                        Q8ByteBufferTensor a = (Q8ByteBufferTensor)at;
                        Q4ByteBufferTensor b = (Q4ByteBufferTensor)bt;
                        NativeSimd.gemm_q8_q4((int)this.flags, (MemorySegment)a.getBlockF().getMemorySegment(), (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, aColumnOffset}), (MemorySegment)b.getBlockF().getMemorySegment(), (MemorySegment)b.getMemorySegment(), (int)b.getMemorySegmentOffset(b.getOffset(new int[]{0, bColumnOffset})), (MemorySegment)result.getMemorySegment(), (int)result.shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)a.getBlockF().getStride(), (int)b.getMemorySegmentOffset(b.getStride()), (int)b.getBlockF().getStride(), (int)result.getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
            }
            default: {
                throw new UnsupportedOperationException(at.dType().name());
            }
        }
    }

    public void dotProductBatchChunk(AbstractTensor[] r, AbstractTensor a, AbstractTensor[] b, int columnOffset, int columnLength, int bRowOffset, int rowChunkSize) {
        MemorySegment[] tmp = MemorySegmentSupport.setupBatch(i -> r[i].getMemorySegment(), i -> b[i].getMemorySegment(), i -> b[i] instanceof Q4ByteBufferTensor ? ((Q4ByteBufferTensor)b[i]).getBlockF().getMemorySegment() : MemorySegment.NULL, r.length);
        MemorySegment ra = tmp[0];
        MemorySegment rb = tmp[1];
        MemorySegment rc = tmp[2];
        int M = a.shape().dim(0);
        int N = rowChunkSize;
        int K = columnLength;
        block0 : switch (a.dType()) {
            case BF16: {
                switch (b[0].dType()) {
                    case BF16: {
                        NativeSimd.gemm_bf16_batch((int)this.flags, (int)r.length, (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, columnOffset}), (MemorySegment)rb, (int)b[0].getOffset(new int[]{0, columnOffset}), (MemorySegment)(r[0].dType() == DType.BF16 ? ra : MemorySegment.NULL), (MemorySegment)(r[0].dType() == DType.F32 ? ra : MemorySegment.NULL), (int)r[0].shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)b[0].getStride(), (int)r[0].getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(a.dType().name() + " " + b[0].dType().name());
            }
            case F32: {
                switch (b[0].dType()) {
                    case F32: {
                        NativeSimd.gemm_f32_batch((int)this.flags, (int)r.length, (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, columnOffset}), (MemorySegment)rb, (int)b[0].getOffset(new int[]{0, columnOffset}), (MemorySegment)ra, (int)r[0].shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)b[0].getStride(), (int)r[0].getStride());
                        break block0;
                    }
                    case BF16: {
                        NativeSimd.gemm_f32_bf16_batch((int)this.flags, (int)r.length, (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, columnOffset}), (MemorySegment)rb, (int)b[0].getOffset(new int[]{0, columnOffset}), (MemorySegment)(r[0].dType() == DType.BF16 ? ra : MemorySegment.NULL), (MemorySegment)(r[0].dType() == DType.F32 ? ra : MemorySegment.NULL), (int)r[0].shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)b[0].getStride(), (int)r[0].getStride());
                        break block0;
                    }
                    case Q4: {
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128: {
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            }
                        }
                        Q4ByteBufferTensor bt = (Q4ByteBufferTensor)b[0];
                        NativeSimd.gemm_f32_q4_batch((int)this.flags, (int)r.length, (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, columnOffset}), (MemorySegment)rc, (MemorySegment)rb, (int)b[0].getMemorySegmentOffset(b[0].getOffset(new int[]{0, columnOffset})), (MemorySegment)ra, (int)r[0].shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)b[0].getMemorySegmentOffset(b[0].getStride()), (int)bt.getBlockF().getStride(), (int)r[0].getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(a.dType().name() + " " + b[0].dType().name());
            }
            case I8: {
                switch (b[0].dType()) {
                    case Q4: {
                        Q8ByteBufferTensor at = (Q8ByteBufferTensor)a;
                        Q4ByteBufferTensor bt = (Q4ByteBufferTensor)b[0];
                        NativeSimd.gemm_q8_q4_batch((int)this.flags, (int)r.length, (MemorySegment)at.getBlockF().getMemorySegment(), (MemorySegment)a.getMemorySegment(), (int)a.getOffset(new int[]{0, columnOffset}), (MemorySegment)rc, (MemorySegment)rb, (int)bt.getMemorySegmentOffset(bt.getOffset(new int[]{0, columnOffset})), (MemorySegment)ra, (int)r[0].shape().sparseOffset(), (int)M, (int)bRowOffset, (int)N, (int)K, (int)a.getStride(), (int)at.getBlockF().getStride(), (int)bt.getMemorySegmentOffset(bt.getStride()), (int)bt.getBlockF().getStride(), (int)r[0].getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(a.dType().name() + " " + b[0].dType().name());
            }
            default: {
                throw new UnsupportedOperationException(a.dType().name());
            }
        }
    }

    public void accumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        delegate.accumulate(a, b, offset, length);
    }

    public void maccumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        delegate.maccumulate(a, b, offset, length);
    }

    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        delegate.saxpy(alpha, x, y, xoffset, yoffset, limit);
    }

    public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) {
        delegate.saxpy(alpha, x, y, xoffset, yoffset, limit, batchSize);
    }

    public void scale(float factor, AbstractTensor x, int offset, int length) {
        delegate.scale(factor, x, offset, length);
    }

    public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int length) {
        return delegate.quantize(t, qtype, offset, length);
    }

    static {
        PanamaTensorOperations tmp;
        logger = LoggerFactory.getLogger(NativeTensorOperations.class);
        if (!JarSupport.maybeLoadLibrary()) {
            System.loadLibrary("jlama");
        }
        HAS_F16C = NativeSimd.HAS_F16C();
        HAS_AVX2 = NativeSimd.HAS_AVX2();
        try {
            tmp = new PanamaTensorOperations(MachineSpec.VECTOR_TYPE);
        }
        catch (Throwable t) {
            tmp = new NaiveTensorOperations();
        }
        delegate = tmp;
    }
}

