package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.HttpSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link WeightLoader} that lazily fetches safetensors weights over HTTP and caches them under a
 * local model root.
 *
 * <p>Nothing is downloaded up front. The first time a tensor from a given weight file is
 * requested, the first 1 MiB of that file is downloaded to
 * {@code <modelRoot>/<weightFile>.header} and parsed for tensor metadata. After that, only the
 * byte range computed by {@code Weights.getLoadOffsets} for the requested tensor is downloaded,
 * to {@code <modelRoot>/<weightFile>.part.<start>_<end>}. Files that already exist locally are
 * not downloaded again.
 *
 * <p>NOTE(review): state is held in plain {@link HashMap}s with no synchronization; treat
 * instances as not thread-safe unless callers serialize access.
 */
public class HTTPSafeTensorLoader implements WeightLoader {
    private static final Logger logger = LoggerFactory.getLogger(HTTPSafeTensorLoader.class);

    /** Number of leading bytes fetched from a weight file in order to parse its header. */
    private static final long HEADER_PREFETCH_BYTES = 1024L * 1024L;

    private final Path modelRoot;
    private final String indexFile;
    private final String modelName;
    private final Optional<String> branch;
    private final Optional<String> authToken;
    private final SafeTensorIndex index;
    /** Open part files together with the tensor loaded from each, keyed by tensor name. */
    private final Map<String, Pair<RandomAccessFile, AbstractTensor>> layerFiles = new HashMap<>();
    /** Tensor metadata parsed from downloaded headers, keyed by tensor name. */
    private final Map<String, TensorInfo> dynamicTensorInfoMap = new HashMap<>();
    /**
     * Buffer position reached after parsing each tensor's file header. It is added to the
     * tensor-relative offsets to obtain absolute byte positions in the remote file. Keyed by
     * tensor name.
     */
    private final Map<String, Integer> tensorFileOffsets = new HashMap<>();
    private final DType modelDType;

    /**
     * Creates a loader for the remote model {@code owner + "/" + name}.
     *
     * <p>If {@code <modelRoot>/<MODEL_INDEX_JSON>} exists locally, it is parsed as the weight
     * index. Otherwise the model is treated as a single weight file.
     *
     * @param modelRoot  local directory holding the index and the downloaded header/part files
     * @param owner      first component of the remote model id
     * @param name       second component of the remote model id
     * @param modelDType dtype passed to {@code Weights.loadTensorFromBuffer} as the target type
     * @param branch     optional branch/revision forwarded to every download
     * @param authToken  optional auth token forwarded to every download
     * @throws RuntimeException if the local index file exists but cannot be read
     */
    public HTTPSafeTensorLoader(Path modelRoot, String owner, String name, DType modelDType, Optional<String> branch, Optional<String> authToken) {
        this.modelRoot = modelRoot;
        this.modelName = owner + "/" + name;
        this.branch = branch;
        this.authToken = authToken;
        this.modelDType = modelDType;
        this.indexFile = String.format("%s/%s", modelRoot, SafeTensorIndex.MODEL_INDEX_JSON);

        File indexPath = new File(this.indexFile);
        if (indexPath.exists()) {
            try {
                this.index = (SafeTensorIndex) JsonSupport.om.readValue(indexPath, SafeTensorIndex.class);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else {
            // No local index: a one-entry weight map, so every tensor resolves to SINGLE_MODEL_NAME.
            this.index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", SafeTensorIndex.SINGLE_MODEL_NAME));
        }
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, String> metadata() {
        return this.index.metadata();
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.dynamicTensorInfoMap;
    }

    /**
     * Loads (downloading if necessary) the tensor {@code name}, or returns it from the cache if
     * it was already loaded.
     *
     * @param name          tensor name
     * @param dctx          distribution context forwarded to {@code Weights.getLoadOffsets} and
     *                      {@code Weights.loadTensorFromBuffer}
     * @param sparseRows    forwarded to the {@code Weights} helpers; mutually exclusive with
     *                      {@code sparseColumns}
     * @param sparseColumns forwarded to {@code Weights.loadTensorFromBuffer}
     * @return the loaded tensor
     * @throws IllegalArgumentException if both sparse flags are set or the weight is unknown
     * @throws RuntimeException if I/O fails or the downloaded part file is shorter than expected
     */
    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) {
        Preconditions.checkArgument(!(sparseRows && sparseColumns), "Cannot have both sparse rows and columns");
        Preconditions.checkArgument(this.index.weightFileMap.containsKey(name) || this.index.weightFileMap.size() == 1, "Unknown weight: " + name);

        Pair<RandomAccessFile, AbstractTensor> cached = this.layerFiles.get(name);
        if (cached != null) {
            return cached.right();
        }

        try {
            TensorInfo info = maybeLoadTensorInfo(name);
            Pair<TensorShape, Pair<Long, Long>> offsets = Weights.getLoadOffsets(info, dctx, sparseRows);

            Integer headerOffset = this.tensorFileOffsets.get(name);
            assert headerOffset != null && headerOffset > 0 : "Failed to find header offset for: " + name;

            TensorShape shape = offsets.left;
            // Absolute byte range of this tensor (slice) within the remote weight file.
            long start = offsets.right.left + headerOffset;
            long end = offsets.right.right + headerOffset;

            String weightFile = this.index.weightFileMap.getOrDefault(name, SafeTensorIndex.SINGLE_MODEL_NAME);
            // One cache file per distinct byte range, so different slices never share a file.
            Path partPath = this.modelRoot.resolve(weightFile + ".part." + start + "_" + end);
            if (!partPath.toFile().exists()) {
                logger.info("Downloading file: {} for {} {}MB", partPath, name, (end - start) / 1024 / 1024);
                HttpSupport.downloadFile(this.modelName, weightFile, this.branch, this.authToken, Optional.of(Pair.of(start, end)), partPath, Optional.empty());
            }

            int length = Ints.checkedCast(end - start);
            RandomAccessFile raf = new RandomAccessFile(partPath.toFile(), "r");
            try {
                long fileLength = raf.length();
                // Check before limit(): limiting past capacity would throw an unhelpful IllegalArgumentException.
                if (fileLength < length) {
                    throw new RuntimeException("Failed to download the correct number of bytes: " + fileLength + " != " + length + " for " + partPath);
                }
                ByteBuffer buf = raf.getChannel()
                    .map(FileChannel.MapMode.READ_ONLY, 0L, fileLength)
                    .duplicate()
                    .order(ByteOrder.LITTLE_ENDIAN)
                    .position(0)
                    .limit(length);

                logger.debug("Loading tensor: {} from {} with offsets: {} {}", name, partPath, start, end);
                AbstractTensor tensor = Weights.loadTensorFromBuffer(name, info.dType, this.modelDType, shape, buf, sparseRows, sparseColumns, dctx, this);
                // Keep the file open for the loader's lifetime; it is closed in close().
                this.layerFiles.put(name, Pair.of(raf, tensor));
                return tensor;
            } catch (Throwable t) {
                // The file was never registered in layerFiles, so close it here to avoid a leak.
                try {
                    raf.close();
                } catch (IOException suppressed) {
                    t.addSuppressed(suppressed);
                }
                throw t;
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the metadata for {@code name}. If it is not cached yet, the weight file's header
     * is downloaded (first 1 MiB, only when absent locally) and every tensor it describes is
     * registered.
     */
    private TensorInfo maybeLoadTensorInfo(String name) throws IOException {
        TensorInfo cached = this.dynamicTensorInfoMap.get(name);
        if (cached != null) {
            return cached;
        }

        String weightFile = this.index.weightFileMap.getOrDefault(name, SafeTensorIndex.SINGLE_MODEL_NAME);
        Path headerPath = this.modelRoot.resolve(weightFile + ".header");
        if (!Files.exists(headerPath)) {
            HttpSupport.downloadFile(this.modelName, weightFile, this.branch, this.authToken, Optional.of(Pair.of(0L, HEADER_PREFETCH_BYTES)), headerPath, Optional.empty());
        }

        try (RandomAccessFile raf = new RandomAccessFile(headerPath.toFile(), "r")) {
            MappedByteBuffer header = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(HEADER_PREFETCH_BYTES, raf.length()));
            Map<String, TensorInfo> infos = SafeTensorSupport.readTensorInfoMap(header, Optional.empty());
            // Presumably readTensorInfoMap leaves the buffer at the start of tensor data; that
            // position is the base for every tensor offset in this file.
            int dataStart = header.position();
            for (Map.Entry<String, TensorInfo> entry : infos.entrySet()) {
                this.dynamicTensorInfoMap.put(entry.getKey(), entry.getValue());
                this.tensorFileOffsets.put(entry.getKey(), dataStart);
            }
        }

        assert this.dynamicTensorInfoMap.containsKey(name) : "Failed to load tensor info for: " + name;
        return this.dynamicTensorInfoMap.get(name);
    }

    @Override // com.github.tjake.jlama.safetensors.WeightLoader
    public DType getModelDType() {
        return this.modelDType;
    }

    /**
     * Closes every open part file and clears all cached state.
     *
     * <p>Every file is attempted even if an earlier close fails. The first failure is rethrown
     * (wrapped) after cleanup, with any later failures attached as suppressed exceptions.
     *
     * @throws RuntimeException if closing any file failed
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        IOException failure = null;
        for (Pair<RandomAccessFile, AbstractTensor> entry : this.layerFiles.values()) {
            try {
                entry.left().close();
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        this.layerFiles.clear();
        this.dynamicTensorInfoMap.clear();
        this.tensorFileOffsets.clear();
        if (failure != null) {
            throw new RuntimeException(failure);
        }
    }
}
