package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.Float16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.Pair;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.EnumMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/Weights.class */
public class Weights implements WeightLoader {
    private static final Logger logger = LoggerFactory.getLogger(Weights.class);
    private final Map<String, String> metadata;
    private final Map<String, TensorInfo> tensorInfoMap;
    private final ByteBuffer bytes;
    private final DType majorityDType;
    private final Optional<WeightLoader> parent;

    /**
     * Creates a loader over a single tensor buffer.
     *
     * @param metadata      key/value metadata; copied into an immutable map
     * @param tensorInfoMap tensor name to descriptor (dtype, shape, byte offsets into {@code bytes});
     *                      copied into an immutable map
     * @param bytes         backing buffer; a duplicate is kept, so the caller's position/limit are
     *                      never modified by this class
     * @param parent        loader used to resolve {@code ".qb"} companion tensors for Q4/I8 tensors;
     *                      when empty, companions are looked up in this instance
     */
    public Weights(
            Map<String, String> metadata,
            Map<String, TensorInfo> tensorInfoMap,
            ByteBuffer bytes,
            Optional<WeightLoader> parent) {
        this.metadata = ImmutableMap.copyOf(metadata);
        this.tensorInfoMap = ImmutableMap.copyOf(tensorInfoMap);
        this.bytes = bytes.duplicate();
        this.parent = parent;
        // Must run after tensorInfoMap is assigned; as a field initializer it ran first and NPE'd.
        this.majorityDType = findDType(this.tensorInfoMap);
    }

    /**
     * Returns the most frequent dtype among tensors whose names do not end in {@code ".qb"}.
     *
     * <p>If the most frequent dtype is F16, F32 is reported instead; {@link #load} then widens
     * F16/BF16 tensors to F32. Ties go to the enum constant declared first, because
     * {@link EnumMap} iterates in declaration order and only a strictly larger count replaces the
     * current winner.
     *
     * @return the majority dtype, or {@code null} if there are no non-{@code ".qb"} tensors
     */
    private static DType findDType(Map<String, TensorInfo> tensors) {
        EnumMap<DType, Integer> counts = new EnumMap<>(DType.class);
        for (Map.Entry<String, TensorInfo> entry : tensors.entrySet()) {
            // ".qb" entries are companions of Q4/I8 tensors (see load) and are not counted.
            if (!entry.getKey().endsWith(".qb")) {
                counts.merge(entry.getValue().dType, 1, Integer::sum);
            }
        }

        DType best = null;
        int bestCount = 0;
        for (Map.Entry<DType, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                best = entry.getKey();
            }
        }
        return best == DType.F16 ? DType.F32 : best;
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, String> metadata() {
        return this.metadata;
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.tensorInfoMap;
    }

    /**
     * Materializes the named tensor from the backing buffer.
     *
     * <ul>
     *   <li>F32 is wrapped as a float view over the buffer.</li>
     *   <li>F16/BF16 are wrapped as short views, unless the majority dtype is F32. In that case
     *       they are converted into a freshly allocated little-endian F32 copy.</li>
     *   <li>Q4/I8 are wrapped together with the {@code name + ".qb"} companion tensor, loaded
     *       from the parent loader if present, otherwise from this loader.</li>
     * </ul>
     *
     * @param name          tensor name as it appears in {@link #tensorInfoMap()}
     * @param sparsifyRange if present, the result is passed through {@code sparsify(left, right)};
     *                      the same range is forwarded when loading a ".qb" companion
     * @return the loaded (and possibly sparsified) tensor
     * @throws NoSuchElementException   if {@code name} is not in the tensor map
     * @throws RuntimeException         if the tensor's shape has no dimensions
     * @throws IllegalArgumentException if the tensor's dtype is not supported
     */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public AbstractTensor load(String name, Optional<Pair<Integer, Integer>> sparsifyRange)
            throws NoSuchElementException {
        TensorInfo info = this.tensorInfoMap.get(name);
        if (info == null) {
            throw new NoSuchElementException(name + " not found in weights");
        }
        if (info.shape.length < 1) {
            throw new RuntimeException("Invalid shape dimensions " + info.shape.length + " encountered for " + name);
        }

        // Private little-endian window onto this tensor's bytes; the shared buffer is untouched.
        ByteBuffer data = this.bytes.duplicate()
                .order(ByteOrder.LITTLE_ENDIAN)
                .position(Ints.checkedCast(info.dataOffsets[0]))
                .limit(Ints.checkedCast(info.dataOffsets[1]));
        TensorShape shape = TensorShape.of(info.shape);

        // Arrow-form switch: no case can fall through into the next one.
        AbstractTensor tensor = switch (info.dType) {
            case F32 -> new FloatBufferTensor(name, data.asFloatBuffer().slice(), shape, true);
            case F16 -> this.majorityDType == DType.F32
                    ? widenToF32(data, DType.F16, shape, Float::float16ToFloat)
                    : new Float16BufferTensor(name, data.asShortBuffer().slice(), shape, true);
            case BF16 -> this.majorityDType == DType.F32
                    ? widenToF32(data, DType.BF16, shape, FloatConversions::bFloat16ToFloat32)
                    : new BFloat16BufferTensor(name, data.asShortBuffer().slice(), shape, true);
            case Q4 -> new Q4ByteBufferTensor(name, data.slice(), loadCompanion(name, sparsifyRange), shape, true);
            case I8 -> new Q8ByteBufferTensor(name, data.slice(), loadCompanion(name, sparsifyRange), shape, true);
            default -> throw new IllegalArgumentException(
                    "Unsupported Tensor type: " + info.dType.name() + " for " + name);
        };

        return sparsifyRange
                .map(range -> tensor.sparsify(range.left, range.right))
                .orElse(tensor);
    }

    /** Decodes one 16-bit floating-point value, given as its raw bits, to a 32-bit float. */
    @FunctionalInterface
    private interface HalfToFloat {
        float apply(short bits);
    }

    /**
     * Copies every 16-bit element remaining in {@code src} into a new heap buffer of F32 values.
     * The copy is little-endian and wrapped as a {@link FloatBufferTensor}.
     *
     * <p>Advances {@code src}'s position to its limit.
     */
    private static FloatBufferTensor widenToF32(
            ByteBuffer src, DType srcType, TensorShape shape, HalfToFloat convert) {
        int count = src.remaining() / srcType.size();
        int stride = DType.F32.size();
        ByteBuffer out = ByteBuffer.allocate(count * stride).order(ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i < count; i++) {
            out.putFloat(i * stride, convert.apply(src.getShort()));
        }
        return new FloatBufferTensor(out.asFloatBuffer(), shape, true);
    }

    /**
     * Loads the {@code name + ".qb"} companion of a quantized tensor.
     *
     * <p>The companion comes from the parent loader when one was supplied, otherwise from this
     * loader. The result is cast to {@link FloatBufferTensor}.
     * NOTE(review): presumably the per-block scale factors — confirm.
     */
    private FloatBufferTensor loadCompanion(String name, Optional<Pair<Integer, Integer>> sparsifyRange) {
        WeightLoader source = this.parent.orElse(this);
        return (FloatBufferTensor) source.load(name + ".qb", sparsifyRange);
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public DType getModelDType() {
        return this.majorityDType;
    }

    @Override
    public String toString() {
        return "SafeTensor{metadata=" + String.valueOf(this.metadata) + ", tensorInfoMap=" + String.valueOf(this.tensorInfoMap) + ", bytes=" + String.valueOf(this.bytes) + "}";
    }

    /** Two instances are equal when their metadata and tensor maps are equal; the buffer and parent are ignored. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Weights other = (Weights) obj;
        return Objects.equals(this.metadata, other.metadata)
                && Objects.equals(this.tensorInfoMap, other.tensorInfoMap);
    }

    /** Consistent with {@link #equals}: hashes only metadata and the tensor map. */
    @Override
    public int hashCode() {
        return Objects.hash(this.metadata, this.tensorInfoMap);
    }

    /** No-op: this class does not release or close {@code bytes} or the parent loader. */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
    }
}
