package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.util.Pair;
import java.io.IOError;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/* loaded from: input_file:com/github/tjake/jlama/tensor/KvBufferCache.class */
/**
 * Per-session cache of key/value buffer tensors for a model.
 *
 * <p>Each session {@link UUID} maps to one tensor, created lazily on first request. When the
 * model configuration supplies a working directory, the tensor is backed by a memory-mapped
 * file named after the session id inside that directory; otherwise it is allocated through
 * {@link AbstractTensor#make}.
 *
 * <p>Entries live in a {@link ConcurrentHashMap}, so creation goes through
 * {@code computeIfAbsent}.
 */
public class KvBufferCache {
    /** String key {@code "TOKEN_COUNT"}; not referenced inside this class. */
    public static final String TOKEN_COUNT = "TOKEN_COUNT";

    /** Upper bound applied to the context-length dimension of every buffer. */
    private static final int MAX_BUFFERED_CONTEXT = 1024;

    /** Session id -> (backing file, or {@code null} when not file-backed; buffer tensor). */
    private final ConcurrentMap<UUID, Pair<RandomAccessFile, AbstractTensor>> kvBufferCache = new ConcurrentHashMap<>();
    private final AbstractModel model;

    /**
     * @param abstractModel model whose configuration and working dtype determine the shape and
     *     storage of the buffers created by this cache
     */
    public KvBufferCache(AbstractModel abstractModel) {
        this.model = abstractModel;
    }

    /**
     * Returns the buffer tensor for the given session, creating it on first use.
     *
     * @param uuid session identifier; also used as the backing file name when file-backed
     * @return the tensor associated with {@code uuid}
     * @throws IOError if the backing file cannot be created, resized or mapped
     * @throws UnsupportedOperationException if a file-backed buffer is requested for a working
     *     dtype other than F32 or BF16
     */
    public AbstractTensor getKvBuffer(UUID uuid) {
        return this.kvBufferCache.computeIfAbsent(uuid, this::makeKvBuffer).right;
    }

    /**
     * Builds the buffer shape: {@code [layers, 2, min(1024, contextLength), kvLength]}.
     * The constant 2 presumably separates keys from values -- TODO confirm against the callers.
     * When the config carries an offset pair, both ends are divided by {@code headGroupSize}
     * and the shape is made sparse over that range.
     */
    private TensorShape kvShape() {
        int[] dims = {
            this.model.getConfig().getNumberOfLayers(),
            2,
            Math.min(MAX_BUFFERED_CONTEXT, this.model.getConfig().contextLength),
            this.model.getConfig().kvLength
        };
        if (!this.model.getConfig().offset().isPresent()) {
            return TensorShape.of(dims);
        }
        Pair<Integer, Integer> offset = this.model.getConfig().offset().get();
        int groupSize = this.model.getConfig().headGroupSize;
        return TensorShape.sparse(
                dims,
                Pair.create(
                        Integer.valueOf(offset.left.intValue() / groupSize),
                        Integer.valueOf(offset.right.intValue() / groupSize)));
    }

    /**
     * Creates the cache entry for one session.
     *
     * @param uuid session identifier
     * @return pair of (backing file or {@code null}, buffer tensor)
     */
    private Pair<RandomAccessFile, AbstractTensor> makeKvBuffer(UUID uuid) {
        TensorShape shape = kvShape();
        if (this.model.getConfig().workingDirectory().isEmpty()) {
            // No working directory configured: no backing file, so the left side stays null.
            return Pair.create(null, AbstractTensor.make(this.model.getWorkingDType(), shape));
        }

        // Reject unsupported dtypes before touching the filesystem so that no stray file
        // (or open handle) is left behind.
        DType dtype = this.model.getWorkingDType();
        if (dtype != DType.F32 && dtype != DType.BF16) {
            throw new UnsupportedOperationException("Only F32/BF16 is supported for now");
        }

        RandomAccessFile backingFile = null;
        try {
            backingFile = new RandomAccessFile(
                    Paths.get(this.model.getConfig().workingDirectory().get().toString(), uuid.toString()).toFile(),
                    "rw");
            long size = shape.size() * dtype.size();
            backingFile.setLength(size);

            AbstractTensor tensor;
            if (dtype == DType.F32) {
                tensor = new FloatBufferTensor(
                        backingFile.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, size).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(),
                        shape,
                        true);
            } else {
                tensor = new BFloat16BufferTensor(
                        "kvmem",
                        backingFile.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, size).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer(),
                        shape,
                        true);
            }
            return Pair.create(backingFile, tensor);
        } catch (IOException e) {
            closeAfterFailure(backingFile, e);
            throw new IOError(e);
        } catch (RuntimeException e) {
            // e.g. a mapping size rejected by FileChannel.map; do not leak the handle.
            closeAfterFailure(backingFile, e);
            throw e;
        }
    }

    /** Closes a partially initialised backing file, attaching any close error to {@code failure}. */
    private static void closeAfterFailure(RandomAccessFile file, Throwable failure) {
        if (file == null) {
            return;
        }
        try {
            file.close();
        } catch (IOException closeError) {
            failure.addSuppressed(closeError);
        }
    }
}
