package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.google.common.collect.Maps;
import java.util.Objects;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.jctools.queues.MpmcUnboundedXaddArrayQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/tensor/TensorCache.class */
/**
 * A pool of reusable tensors, keyed by {@link DType} and {@link TensorShape}.
 *
 * <p>{@link #get} first tries to reuse a previously {@link #release released} tensor with the same
 * dtype and shape. Otherwise it allocates a fresh one. A freshly allocated tensor is registered with
 * this cache (via {@code setOwnerCache(this)}) only if the running total, after adding the tensor's
 * {@code size()}, stays strictly below the configured capacity. Once the budget is exhausted,
 * tensors are still handed out, but they are not registered.
 *
 * <p>{@link #release} puts tensors back into the pool without decrementing the running total. The
 * total therefore only shrinks when {@link #get} rolls back a reservation that did not fit.
 *
 * <p>State is held in a {@link ConcurrentMap} of multi-producer/multi-consumer queues plus an
 * {@link AtomicLong}, which suggests {@code get}/{@code release} are meant to be called from
 * several threads — NOTE(review): confirm against callers.
 *
 * <p>NOTE(review): the budget is accumulated from {@code AbstractTensor.size()}. Confirm whether that
 * is a byte count or an element count, since the field names assume bytes.
 */
public class TensorCache {
    /** Capacity of the shared {@link #instance}: 104,857,600 (100 * 1024 * 1024). */
    private static final long DEFAULT_CAPACITY_BYTES = 100L * 1024 * 1024;

    /** Chunk size passed to each per-shape {@link MpmcUnboundedXaddArrayQueue}. */
    private static final int QUEUE_CHUNK_SIZE = 128;

    public static final TensorCache instance = new TensorCache(DEFAULT_CAPACITY_BYTES);
    private static final Logger logger = LoggerFactory.getLogger(TensorCache.class);

    /** Exclusive upper bound on {@link #currentBytes} for tensors registered with this cache. */
    private final long bytesCapacity;

    /** Creates the (initially empty) pool queue for a dtype/shape key on first use. */
    private final Function<ShapeKey, MpmcUnboundedXaddArrayQueue<AbstractTensor>> queueFactory =
            key -> new MpmcUnboundedXaddArrayQueue<>(QUEUE_CHUNK_SIZE);

    /** Running total of {@code size()} over tensors registered with this cache. */
    private final AtomicLong currentBytes = new AtomicLong(0);

    /** Released tensors available for reuse, grouped by dtype and shape. */
    private final ConcurrentMap<ShapeKey, MpmcUnboundedXaddArrayQueue<AbstractTensor>> availableByShape =
            Maps.newConcurrentMap();

    /**
     * Map key identifying a pool of interchangeable tensors: same dtype and equal shape.
     *
     * <p>Fields are final, and {@code equals}/{@code hashCode} cover exactly those fields. This
     * makes instances safe to use as hash-map keys, provided {@link TensorShape} itself has stable
     * equality (assumed, not visible here).
     */
    public static class ShapeKey {
        final TensorShape shape;
        final DType dType;

        ShapeKey(DType dType, TensorShape shape) {
            this.dType = dType;
            this.shape = shape;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ShapeKey other = (ShapeKey) obj;
            // DType is switched over as an enum elsewhere in this class, so identity comparison is sufficient.
            return Objects.equals(shape, other.shape) && dType == other.dType;
        }

        @Override
        public int hashCode() {
            return Objects.hash(shape, dType);
        }
    }

    /**
     * Creates a cache with the given budget.
     *
     * @param bytesCapacity exclusive upper bound on the running total of {@code size()} across
     *     tensors registered with this cache
     */
    public TensorCache(long bytesCapacity) {
        this.bytesCapacity = bytesCapacity;
    }

    /**
     * Returns a tensor of the requested dtype and shape, reusing a pooled one when available.
     *
     * <p>On a pool miss, a new tensor is allocated. It is registered with this cache only if
     * doing so keeps the running total strictly below capacity. Otherwise the reservation is rolled
     * back and the tensor is returned unregistered.
     *
     * @param dType element type of the tensor; must be F32, F16, BF16 or I8 when a new tensor has
     *     to be allocated
     * @param tensorShape shape of the tensor
     * @return a pooled or freshly allocated tensor
     * @throws RuntimeException if a new tensor is needed and {@code dType} is unsupported
     */
    public AbstractTensor get(DType dType, TensorShape tensorShape) {
        AbstractTensor pooled =
                availableByShape.computeIfAbsent(new ShapeKey(dType, tensorShape), queueFactory).poll();
        if (pooled != null) {
            return pooled;
        }

        AbstractTensor fresh = allocate(dType, tensorShape);
        long size = fresh.size();

        // Optimistically reserve budget with a single atomic add. If that reached capacity, undo it
        // and hand the tensor out unregistered ("over-allocate") rather than failing.
        if (currentBytes.addAndGet(size) < bytesCapacity) {
            fresh.setOwnerCache(this);
        } else {
            logger.debug("Full!");
            currentBytes.addAndGet(-size);
        }
        return fresh;
    }

    /** Allocates a new, unregistered tensor of the given dtype and shape. */
    private static AbstractTensor allocate(DType dType, TensorShape shape) {
        switch (dType) {
            case F32:
                return new FloatBufferTensor(shape);
            case F16:
                return new Float16BufferTensor(shape);
            case BF16:
                return new BFloat16BufferTensor(shape);
            case I8:
                return new Q8ByteBufferTensor(shape);
            default:
                throw new RuntimeException("Unsupported tensor type: " + dType);
        }
    }

    /**
     * Clears a tensor and returns it to the pool for reuse by a later {@link #get} with the same
     * dtype and shape.
     *
     * <p>The running budget is not changed. Pooled tensors stay accounted for.
     *
     * <p>NOTE(review): decompiler metadata says this was package-private in the original source. It
     * is presumably invoked by tensors whose owner cache is this one — confirm in AbstractTensor.
     *
     * @param abstractTensor tensor to recycle
     */
    public void release(AbstractTensor abstractTensor) {
        abstractTensor.clear();
        ShapeKey key = new ShapeKey(abstractTensor.dType(), abstractTensor.shape());
        availableByShape.computeIfAbsent(key, queueFactory).offer(abstractTensor);
    }
}
