package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.util.Pair;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/SafeTensorIndex.class */
/**
 * A {@link WeightLoader} over one or more safetensors files.
 *
 * <p>Instances are created either from a {@value #MODEL_INDEX_JSON} file, whose {@code weight_map}
 * maps tensor names to the file that stores them ({@link #loadWithWeights}), or from a single file
 * ({@link #loadSingleFile}). Every referenced file is opened read-only and its data section is
 * memory-mapped in one or more splits, each spanning fewer than {@link Integer#MAX_VALUE} bytes.
 * The files stay open until {@link #close()} is called.
 *
 * <p>NOTE(review): the internal maps are plain {@link HashMap}s with no synchronization.
 */
public class SafeTensorIndex implements WeightLoader, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorIndex.class);
    private static final ObjectMapper om = new ObjectMapper();

    public static final String SINGLE_MODEL_NAME = "model.safetensors";
    public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";

    /**
     * Size of the window mapped to parse a file's header (1 MiB, or the whole file if smaller).
     * NOTE(review): headers larger than this are presumably unsupported — confirm against
     * {@code SafeTensorSupport.readTensorInfoMap}.
     */
    private static final long HEADER_WINDOW_BYTES = 1L << 20;

    /** Exclusive upper bound on the byte span of one split (a single mapping is int-indexed). */
    private static final long MAX_SPLIT_BYTES = Integer.MAX_VALUE;

    /** Index-level metadata from the JSON "metadata" section (empty for single-file loads). */
    private final Map<String, String> metadata;
    /** Tensor name (or a placeholder key for single-file loads) -> file name in the model directory. */
    private final Map<String, String> weightFileMap;
    /** Tensor name -> mapped split that holds the tensor's bytes. */
    private final Map<String, Weights> weightMap = new HashMap<>();
    /** File name -> open handle, retained so {@link #close()} can release it. */
    private final Map<String, RandomAccessFile> fileMap = new HashMap<>();

    /**
     * Loads a sharded model described by {@value #MODEL_INDEX_JSON} inside {@code path}.
     *
     * @param path directory containing the index JSON and every file it references
     * @return an index with all referenced files opened and mapped
     * @throws IOException if the index cannot be parsed or a file cannot be opened or read
     */
    public static SafeTensorIndex loadWithWeights(Path path) throws IOException {
        SafeTensorIndex index = om.readValue(Paths.get(path.toString(), MODEL_INDEX_JSON).toFile(), SafeTensorIndex.class);
        return loadOrClose(index, path);
    }

    /**
     * Loads every tensor stored in a single safetensors file.
     *
     * @param path directory containing the file
     * @param str name of the file inside {@code path}
     * @return an index over that one file
     * @throws IOException if the file cannot be opened or read
     */
    public static SafeTensorIndex loadSingleFile(Path path, String str) throws IOException {
        // Only the map's values (file names) are used when opening files; the key is a placeholder.
        SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", str));
        return loadOrClose(index, path);
    }

    /**
     * Runs {@link #loadWeights} and, if it fails, releases any files it already opened before
     * propagating the failure, because the caller never receives the half-built index.
     */
    private static SafeTensorIndex loadOrClose(SafeTensorIndex index, Path path) throws IOException {
        boolean loaded = false;
        try {
            loadWeights(index, path);
            loaded = true;
            return index;
        } finally {
            if (!loaded) {
                index.releaseAll();
            }
        }
    }

    /**
     * Opens each distinct file named in the index's weight file map (resolved against {@code path})
     * and registers its tensors. Files already present in the index are skipped.
     *
     * @throws IOException if a file cannot be opened, mapped, or read
     */
    static void loadWeights(SafeTensorIndex safeTensorIndex, Path path) throws IOException {
        for (String fileName : safeTensorIndex.weightFileMap.values()) {
            if (safeTensorIndex.fileMap.containsKey(fileName)) {
                continue; // many tensors share one file; open it only once
            }
            RandomAccessFile file = new RandomAccessFile(Paths.get(path.toString(), fileName).toFile(), "r");
            // Register before parsing so the handle is released by close() even if parsing fails.
            safeTensorIndex.fileMap.put(fileName, file);
            loadFile(safeTensorIndex, file);
        }
    }

    /** Parses one file's header and maps each of its splits into a {@link Weights} view. */
    private static void loadFile(SafeTensorIndex index, RandomAccessFile file) throws IOException {
        FileChannel channel = file.getChannel();
        long fileLength = file.length();

        MappedByteBuffer header = channel.map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(HEADER_WINDOW_BYTES, fileLength));
        Map<String, String> fileMetadata = new HashMap<>();
        Map<String, TensorInfo> tensorInfoMap = SafeTensorSupport.readTensorInfoMap(header, Optional.of(fileMetadata));
        // The buffer position after parsing is treated as the start of the tensor data; split
        // offsets below are added to it.
        long dataStart = header.position();

        for (Map.Entry<List<Long>, List<String>> split : index.computeMmapSplits(tensorInfoMap, fileLength).entrySet()) {
            long splitStart = split.getKey().get(0);
            long splitEnd = split.getKey().get(1);
            List<String> names = split.getValue();

            Set<String> nameSet = new HashSet<>(names);
            Map<String, TensorInfo> splitInfo = tensorInfoMap.entrySet().stream()
                    .filter(e -> nameSet.contains(e.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            MappedByteBuffer bytes = channel.map(FileChannel.MapMode.READ_ONLY, dataStart + splitStart, splitEnd - splitStart);
            Weights weights = new Weights(fileMetadata, splitInfo, bytes, Optional.of(index));
            for (String name : names) {
                index.weightMap.put(name, weights);
            }
        }
    }

    /**
     * Groups a file's tensors into byte ranges that can each be mapped with one
     * {@link MappedByteBuffer}, i.e. whose span is below {@link Integer#MAX_VALUE}.
     *
     * <p>Side effect: every tensor's {@code dataOffsets} array is rewritten in place so that it is
     * relative to the start of the split containing it, which is where that split's buffer begins.
     *
     * @param map tensors of one file; offsets are relative to the file's data section
     * @param j file length, used as the initial value when searching for a split's minimum start
     * @return {@code [start, end]} of each split (relative to the data section) -> tensor names in it
     * @throws IllegalStateException if some tensor cannot be placed in any split of legal size
     */
    private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo> map, long j) {
        Map<List<Long>, List<String>> splits = new HashMap<>();
        Set<String> assigned = new HashSet<>();
        long lastSplitEnd = 0;

        while (assigned.size() < map.size()) {
            long limit = lastSplitEnd + MAX_SPLIT_BYTES;
            List<String> names = new ArrayList<>();
            long splitStart = j;
            long splitEnd = 0;

            // First pass: choose every unassigned tensor that ends before the limit.
            for (Map.Entry<String, TensorInfo> e : map.entrySet()) {
                long[] offsets = e.getValue().dataOffsets;
                if (!assigned.contains(e.getKey()) && offsets[1] < limit) {
                    names.add(e.getKey());
                    splitStart = Math.min(splitStart, offsets[0]);
                    splitEnd = Math.max(splitEnd, offsets[1]);
                }
            }

            // Without this guard an oversized tensor would make the loop spin forever.
            if (names.isEmpty()) {
                throw new IllegalStateException("Cannot fit any of the " + (map.size() - assigned.size())
                        + " remaining tensors into an mmap split starting at offset " + lastSplitEnd
                        + " (max " + MAX_SPLIT_BYTES + " bytes)");
            }
            if (splitEnd - splitStart >= MAX_SPLIT_BYTES) {
                throw new IllegalStateException("Mmap split too large " + (splitEnd - splitStart)
                        + " > " + MAX_SPLIT_BYTES + " " + lastSplitEnd);
            }
            assigned.addAll(names);

            // Second pass: rebase offsets onto the split start, where the mapped buffer begins.
            for (String name : names) {
                long[] offsets = map.get(name).dataOffsets;
                offsets[0] -= splitStart;
                offsets[1] -= splitStart;
                logger.debug("Adding tensor {} to split {}-{}", name, offsets[0], offsets[1]);
            }

            logger.debug("Adding split {}-{} with {} tensors", splitStart, splitEnd, names.size());
            splits.put(List.of(splitStart, splitEnd), names);
            lastSplitEnd = splitEnd;
        }
        return splits;
    }

    /**
     * Jackson creator for the index JSON.
     *
     * @param metadata the optional {@code "metadata"} object; {@code null} is treated as empty
     * @param weightFileMap the {@code "weight_map"} object (tensor name -> file name); required
     * @throws IllegalArgumentException if {@code weightFileMap} is {@code null}
     */
    @JsonCreator
    SafeTensorIndex(@JsonProperty("metadata") Map<String, String> metadata, @JsonProperty("weight_map") Map<String, String> weightFileMap) {
        this.metadata = metadata == null ? ImmutableMap.of() : ImmutableMap.copyOf(metadata);
        if (weightFileMap == null) {
            throw new IllegalArgumentException("safetensors index is missing \"weight_map\"");
        }
        this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
    }

    /** Returns the index-level metadata (an immutable map). */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, String> metadata() {
        return this.metadata;
    }

    /** Returns a fresh map from every loaded tensor name to its (split-relative) {@link TensorInfo}. */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, TensorInfo> tensorInfoMap() {
        Map<String, TensorInfo> result = new HashMap<>();
        for (Map.Entry<String, Weights> e : this.weightMap.entrySet()) {
            result.put(e.getKey(), e.getValue().tensorInfoMap().get(e.getKey()));
        }
        return result;
    }

    /**
     * Loads a tensor by name, optionally sparsifying it.
     *
     * @param str tensor name
     * @param optional if present, the {@code (left, right)} arguments passed to {@code sparsify}
     * @return the loaded (and possibly sparsified) tensor
     * @throws NoSuchElementException if no tensor with that name was loaded
     */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public AbstractTensor load(String str, Optional<Pair<Integer, Integer>> optional) {
        Weights weights = this.weightMap.get(str);
        if (weights == null) {
            throw new NoSuchElementException(str);
        }
        AbstractTensor tensor = weights.load(str);
        return optional
                .<AbstractTensor>map(pair -> {
                    logger.debug("Sparsifying tensor {} with shape {}", str, pair);
                    return tensor.sparsify(pair.left, pair.right);
                })
                .orElse(tensor);
    }

    /**
     * Returns the dtype reported by an arbitrary loaded split.
     *
     * @throws NoSuchElementException if no weights have been loaded
     */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public DType getModelDType() {
        Iterator<Weights> it = this.weightMap.values().iterator();
        if (!it.hasNext()) {
            throw new NoSuchElementException("No weights loaded; cannot determine model dtype");
        }
        return it.next().getModelDType();
    }

    /** Drops all loaded weights and closes every open file. */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        releaseAll();
    }

    /** Clears the weight map and closes each file, logging (not propagating) close failures. */
    private void releaseAll() {
        this.weightMap.clear();
        this.fileMap.forEach((fileName, file) -> {
            try {
                file.close();
            } catch (IOException e) {
                logger.warn("Failed to close safetensors file {}", fileName, e);
            }
        });
        this.fileMap.clear();
    }
}
