package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.MapType;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q5ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.util.HttpSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.TriConsumer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.RandomAccessFile;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for reading, quantizing and downloading models stored in the safetensors format.
 *
 * <p>As parsed by {@link #readTensorInfoMap}, a safetensors file begins with a little-endian 64-bit
 * header length, followed by that many bytes of JSON mapping tensor names to {@link TensorInfo}
 * (plus an optional {@code __metadata__} object of string pairs), followed by raw tensor data.
 */
public class SafeTensorSupport {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorSupport.class);
    private static final MapType metadataTypeReference =
            JsonSupport.om.getTypeFactory().constructMapType(Map.class, String.class, String.class);

    /** Header key holding free-form string metadata rather than a tensor description. */
    private static final String METADATA_KEY = "__metadata__";

    /** Marker file created in a model directory once every file of a download has been fetched. */
    static String FINISHED_MARKER = ".finished";

    /**
     * Parses the safetensors header at the current position of {@code byteBuffer}.
     *
     * <p>Side effects: the buffer's byte order is switched to little-endian and its position is
     * advanced past the header, so on return it points at the first byte of tensor data.
     *
     * @param byteBuffer buffer positioned at the start of a safetensors file
     * @param metadataSink if present, receives every entry of the header's {@code __metadata__} object
     * @return a mutable map from tensor name to its {@link TensorInfo}
     * @throws IllegalArgumentException if the declared header length is negative or larger than the
     *     bytes remaining in the buffer
     * @throws UncheckedIOException if the header JSON cannot be parsed
     */
    public static Map<String, TensorInfo> readTensorInfoMap(
            ByteBuffer byteBuffer, Optional<Map<String, String>> metadataSink) {
        ByteBuffer buffer = byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
        long headerLength = buffer.getLong();
        // Validate up front so a corrupt/truncated file yields a clear error instead of
        // NegativeArraySizeException or BufferUnderflowException.
        if (headerLength < 0 || headerLength > buffer.remaining()) {
            throw new IllegalArgumentException("Invalid safetensors header length " + headerLength
                    + " (remaining bytes: " + buffer.remaining() + ")");
        }
        byte[] header = new byte[Ints.checkedCast(headerLength)];
        buffer.get(header);

        try {
            Map<String, TensorInfo> tensorInfoMap = new HashMap<>();
            Map<String, String> metadata = Collections.emptyMap();
            Iterator<Map.Entry<String, JsonNode>> fields = JsonSupport.om.readTree(header).fields();
            while (fields.hasNext()) {
                Map.Entry<String, JsonNode> field = fields.next();
                if (field.getKey().equalsIgnoreCase(METADATA_KEY)) {
                    Map<String, String> parsed = JsonSupport.om.treeToValue(field.getValue(), metadataTypeReference);
                    if (parsed != null) {
                        metadata = parsed;
                    }
                } else {
                    tensorInfoMap.put(field.getKey(), JsonSupport.om.treeToValue(field.getValue(), TensorInfo.class));
                }
            }
            Map<String, String> finalMetadata = metadata;
            metadataSink.ifPresent(sink -> sink.putAll(finalMetadata));
            return tensorInfoMap;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to parse safetensors header", e);
        }
    }

    /**
     * Wraps an in-memory safetensors file as {@link Weights}.
     *
     * <p>The caller's buffer is not modified; a duplicate is used for parsing.
     *
     * @param byteBuffer buffer positioned at the start of a safetensors file
     * @return weights backed by a slice of the buffer covering only the tensor data
     */
    public static Weights readWeights(ByteBuffer byteBuffer) {
        ByteBuffer view = byteBuffer.duplicate();
        Map<String, String> metadata = new HashMap<>();
        Map<String, TensorInfo> tensorInfoMap = readTensorInfoMap(view, Optional.of(metadata));
        // readTensorInfoMap advanced 'view' past the header, so this slice starts at the tensor data.
        return new Weights(metadata, tensorInfoMap, view.slice(), Optional.empty());
    }

    /**
     * Reads {@code model_type} from a model's {@code config.json}.
     *
     * @param file the config file
     * @return the matching {@link ModelSupport.ModelType}
     * @throws IllegalArgumentException if the field is missing/not a string, or names an unknown type
     * @throws IOException if the file cannot be read or parsed
     */
    public static ModelSupport.ModelType detectModel(File file) throws IOException {
        JsonNode modelType = JsonSupport.om.readTree(file).get("model_type");
        if (modelType == null || !modelType.isTextual()) {
            throw new IllegalArgumentException("Config missing model_type field.");
        }
        // Locale.ROOT: default-locale casing (e.g. Turkish dotted I) would break enum lookup.
        return ModelSupport.ModelType.valueOf(modelType.textValue().toUpperCase(Locale.ROOT));
    }

    /**
     * Opens the weights in a model directory, preferring a sharded index over a single file.
     *
     * @param file the model directory
     * @return a loader for the index-described shards, or for the single model file
     * @throws IllegalArgumentException if neither the index nor the single model file exists
     */
    public static WeightLoader loadWeights(File file) {
        Path root = file.toPath().toAbsolutePath();
        if (Files.exists(root.resolve(SafeTensorIndex.MODEL_INDEX_JSON))) {
            return SafeTensorIndex.loadWithWeights(file.toPath());
        }
        if (Files.exists(root.resolve(SafeTensorIndex.SINGLE_MODEL_NAME))) {
            return SafeTensorIndex.loadSingleFile(file.toPath(), SafeTensorIndex.SINGLE_MODEL_NAME);
        }
        throw new IllegalArgumentException("No safetensor model found in: " + file);
    }

    /**
     * Returns whether all weight files for a model are present under {@code path}: either the single
     * model file, or the index plus every file it references. Index read errors are logged and
     * reported as {@code false}.
     */
    public static boolean isModelLocal(Path path) {
        if (Files.exists(path.resolve(SafeTensorIndex.SINGLE_MODEL_NAME))) {
            return true;
        }
        Path indexFile = path.resolve(SafeTensorIndex.MODEL_INDEX_JSON);
        if (!Files.exists(indexFile)) {
            return false;
        }
        try {
            SafeTensorIndex index = JsonSupport.om.readValue(indexFile.toFile(), SafeTensorIndex.class);
            for (String weightFile : index.weightFileMap.values()) {
                if (!Files.exists(path.resolve(weightFile))) {
                    return false;
                }
            }
            return true;
        } catch (IOException e) {
            logger.error("Error reading model index", e);
            return false;
        }
    }

    /**
     * Builds a {@link TokenizerModel} from {@code tokenizer.json}, enriched with settings from
     * {@code tokenizer_config.json} when that file exists.
     *
     * @param path model directory
     * @return the configured tokenizer model
     * @throws IllegalArgumentException if {@code tokenizer.json} is missing, lacks a {@code model} key,
     *     or {@code chat_template} has an unsupported shape
     * @throws IOException if a file cannot be read or parsed
     */
    public static TokenizerModel loadTokenizer(Path path) throws IOException {
        File tokenizerFile = path.resolve("tokenizer.json").toFile();
        Preconditions.checkArgument(tokenizerFile.exists(), "No tokenizer.json found in " + path);
        JsonNode root = JsonSupport.om.readTree(tokenizerFile);
        if (!root.has("model")) {
            throw new IllegalArgumentException("Json missing 'model' key");
        }

        TokenizerModel tokenizerModel = JsonSupport.om.treeToValue(root.get("model"), TokenizerModel.class);
        if (root.has("added_tokens")) {
            tokenizerModel.setAddedTokens(JsonSupport.om.convertValue(root.get("added_tokens"), List.class));
        }
        if (root.has("pre_tokenizer")) {
            tokenizerModel.setPreTokenizer(
                    JsonSupport.om.treeToValue(root.get("pre_tokenizer"), TokenizerModel.PreTokenizer.class));
        }
        if (root.has("normalizer")) {
            tokenizerModel.setNormalizer(
                    JsonSupport.om.treeToValue(root.get("normalizer"), TokenizerModel.Normalizer.class));
        }

        File configFile = path.resolve("tokenizer_config.json").toFile();
        if (!configFile.exists()) {
            return tokenizerModel;
        }
        JsonNode config = JsonSupport.om.readTree(configFile);
        if (config.has("legacy")) {
            tokenizerModel.setLegacy(config.get("legacy").asBoolean());
        }
        // A JSON null template is treated as "no template" rather than a format error.
        if (config.hasNonNull("chat_template")) {
            tokenizerModel.setPromptTemplates(parseChatTemplates(config.get("chat_template")));
        }
        String eosToken = specialTokenText(config.get("eos_token"));
        if (eosToken != null) {
            tokenizerModel.setEosToken(eosToken);
        }
        String bosToken = specialTokenText(config.get("bos_token"));
        if (bosToken != null) {
            tokenizerModel.setBosToken(bosToken);
        }
        return tokenizerModel;
    }

    /**
     * Converts {@code chat_template} into name-to-template pairs: a bare string becomes
     * {@code "default"}; an array must hold objects with {@code name} and {@code template} fields.
     */
    private static Map<String, String> parseChatTemplates(JsonNode chatTemplate) {
        Map<String, String> templates = new HashMap<>();
        if (chatTemplate.isTextual()) {
            templates.put("default", chatTemplate.asText());
            return templates;
        }
        if (!chatTemplate.isArray()) {
            throw new IllegalArgumentException("Invalid chat_template format");
        }
        for (JsonNode entry : chatTemplate) {
            if (!entry.has("name") || !entry.has("template")) {
                throw new IllegalArgumentException("Invalid chat_template format");
            }
            templates.put(entry.get("name").asText(), entry.get("template").asText());
        }
        return templates;
    }

    /**
     * Extracts a special-token string that may be stored either as a plain string or as an object
     * with a {@code content} field (as some Hugging Face tokenizer configs serialize it).
     * Returns {@code null} when the node is absent or JSON null.
     */
    private static String specialTokenText(JsonNode node) {
        if (node == null || node.isNull()) {
            return null;
        }
        if (node.isObject() && node.has("content")) {
            return node.get("content").asText();
        }
        return node.asText();
    }

    /**
     * Writes a quantized copy of the model in {@code modelRoot} as a single safetensors file.
     *
     * <p>{@code config.json} and {@code tokenizer.json} are required and copied; {@code README.md}
     * and {@code tokenizer_config.json} are copied when present. Existing copies in the output
     * directory are not overwritten (the copy fails instead).
     *
     * @param modelRoot source model directory
     * @param modelQuantization target type for tensors that are not skipped
     * @param skipLayerPatterns tensors whose name contains any of these are copied unquantized; may
     *     be {@code null}
     * @param dropLayerPrefixes tensors whose name starts with any of these are omitted; may be
     *     {@code null}
     * @param outputRoot output directory; defaults to {@code <parent>/<name>-Jlama-<DTYPE>}
     * @return the output directory
     * @throws UnsupportedOperationException if a tensor ends up in a type that cannot be written
     * @throws IOException on any read/write failure
     */
    public static Path quantizeModel(
            Path modelRoot,
            DType modelQuantization,
            String[] skipLayerPatterns,
            String[] dropLayerPrefixes,
            Optional<Path> outputRoot)
            throws IOException {
        File tmp = File.createTempFile("safe", "tensor");
        tmp.deleteOnExit();
        try {
            WeightLoader weights = loadWeights(modelRoot.toFile());
            Map<String, Object> writtenInfo =
                    writeTensorData(weights, modelQuantization, skipLayerPatterns, dropLayerPrefixes, tmp);

            Path outputPath = outputRoot.orElseGet(() -> defaultQuantizedPath(modelRoot, modelQuantization));
            Files.createDirectories(outputPath);
            Files.copy(modelRoot.resolve("config.json"), outputPath.resolve("config.json"));
            Files.copy(modelRoot.resolve("tokenizer.json"), outputPath.resolve("tokenizer.json"));
            copyIfPresent(modelRoot, outputPath, "README.md");
            copyIfPresent(modelRoot, outputPath, "tokenizer_config.json");

            writeSafeTensorFile(outputPath.resolve(SafeTensorIndex.SINGLE_MODEL_NAME), writtenInfo, tmp.toPath());
            return outputPath;
        } finally {
            // The temp file holds every written tensor; reclaim it now rather than at JVM exit.
            try {
                Files.deleteIfExists(tmp.toPath());
            } catch (IOException e) {
                logger.warn("Could not delete temporary file {}", tmp, e);
            }
        }
    }

    /**
     * Saves every non-dropped tensor of {@code weights} into {@code dataFile}, quantizing it to
     * {@code target} unless its name matches a skip pattern.
     *
     * @return header entries describing what was written, keyed by tensor name (plus a
     *     {@code name.qb} entry with the {@code getBlockF()} data for Q4/I8 tensors)
     */
    private static Map<String, Object> writeTensorData(
            WeightLoader weights, DType target, String[] skipPatterns, String[] dropPrefixes, File dataFile)
            throws IOException {
        Map<String, Object> writtenInfo = new HashMap<>();
        try (RandomAccessFile raf = new RandomAccessFile(dataFile, "rw")) {
            FileChannel out = raf.getChannel();
            for (String name : weights.tensorInfoMap().keySet()) {
                if (startsWithAny(name, dropPrefixes)) {
                    logger.info("Dropping layer: {}", name);
                    continue;
                }
                AbstractTensor tensor = weights.load(name);
                try {
                    boolean skipQuantization = containsAny(name, skipPatterns);
                    if (skipQuantization) {
                        logger.info("Skipping quantization of layer: {}", name);
                    }
                    AbstractTensor result = skipQuantization ? tensor : tensor.quantize(target);
                    writeTensor(name, result, out, writtenInfo);
                } finally {
                    if (tensor != null) {
                        tensor.close();
                    }
                }
            }
        }
        return writtenInfo;
    }

    /** Saves one tensor (and its block data for quantized types) and records the header entries. */
    private static void writeTensor(String name, AbstractTensor tensor, FileChannel out, Map<String, Object> writtenInfo)
            throws IOException {
        switch (tensor.dType()) {
            case F32:
            case BF16:
            case F16:
                writtenInfo.put(name, tensor.save(out));
                break;
            case Q4:
                writtenInfo.put(name, tensor.save(out));
                writtenInfo.put(name + ".qb", ((Q4ByteBufferTensor) tensor).getBlockF().save(out));
                break;
            case Q5:
                // Q5 output is unfinished; fail before writing anything instead of after a partial entry.
                throw new UnsupportedOperationException("TODO");
            case I8:
                writtenInfo.put(name, tensor.save(out));
                writtenInfo.put(name + ".qb", ((Q8ByteBufferTensor) tensor).getBlockF().save(out));
                break;
            default:
                throw new UnsupportedOperationException(tensor.dType() + " not implemented");
        }
    }

    /** Returns true if {@code name} starts with any of {@code prefixes} ({@code null} means none). */
    private static boolean startsWithAny(String name, String[] prefixes) {
        if (prefixes == null) {
            return false;
        }
        for (String prefix : prefixes) {
            if (name.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code name} contains any of {@code patterns} ({@code null} means none). */
    private static boolean containsAny(String name, String[] patterns) {
        if (patterns == null) {
            return false;
        }
        for (String pattern : patterns) {
            if (name.contains(pattern)) {
                return true;
            }
        }
        return false;
    }

    /** Default output directory: a sibling of the model directory named {@code <name>-Jlama-<DTYPE>}. */
    private static Path defaultQuantizedPath(Path modelRoot, DType quantization) {
        // Absolute form so a bare relative name (no parent) still has a sibling directory.
        Path source = modelRoot.toAbsolutePath().normalize();
        Path parent = source.getParent();
        Path name = source.getFileName();
        Preconditions.checkArgument(
                parent != null && name != null, "Cannot derive an output directory from model root %s", modelRoot);
        return parent.resolve(name + "-Jlama-" + quantization.name());
    }

    /** Copies {@code fileName} from {@code fromDir} to {@code toDir} only if the source exists. */
    private static void copyIfPresent(Path fromDir, Path toDir, String fileName) throws IOException {
        Path source = fromDir.resolve(fileName);
        if (Files.exists(source)) {
            Files.copy(source, toDir.resolve(fileName));
        }
    }

    /**
     * Writes {@code target} as: little-endian 8-byte header length, JSON header, then the contents of
     * {@code tensorData}.
     */
    private static void writeSafeTensorFile(Path target, Map<String, Object> header, Path tensorData)
            throws IOException {
        byte[] headerBytes = JsonSupport.om.writeValueAsBytes(header);
        try (RandomAccessFile raf = new RandomAccessFile(target.toFile(), "rw")) {
            // "rw" does not truncate: drop any stale bytes left by a previous, larger file.
            raf.setLength(0);
            FileChannel channel = raf.getChannel();
            logger.debug("pos = {}", channel.position());
            byte[] lengthPrefix = ByteBuffer.allocate(Long.BYTES)
                    .order(ByteOrder.LITTLE_ENDIAN)
                    .putLong(headerBytes.length)
                    .array();
            raf.write(lengthPrefix);
            logger.debug("pos = {}", channel.position());
            raf.write(headerBytes);
            logger.debug("pos = {}", channel.position());

            try (FileChannel source = FileChannel.open(tensorData)) {
                long base = channel.position();
                long size = source.size();
                long copied = 0;
                // transferFrom may move fewer bytes than requested, so loop until done.
                while (copied < size) {
                    long moved = channel.transferFrom(source, base + copied, size - copied);
                    if (moved <= 0) {
                        throw new IOException("Short copy of tensor data: " + copied + " of " + size + " bytes");
                    }
                    copied += moved;
                }
            }
        }
    }

    /**
     * Downloads {@code owner/name} (or just {@code name}) into {@code modelDir} unless it is already
     * complete there, including weight files.
     *
     * @throws IllegalArgumentException if the name has more than one '/' or an empty component
     */
    public static File maybeDownloadModel(String modelDir, String fullModelName) throws IOException {
        // limit -1 keeps trailing empty strings so "owner/" is rejected rather than read as a name.
        String[] parts = fullModelName.split("/", -1);
        if (parts.length > 2) {
            throw new IllegalArgumentException("Model must be in the form owner/name");
        }
        for (String part : parts) {
            if (part.isEmpty()) {
                throw new IllegalArgumentException("Model must be in the form owner/name");
            }
        }
        String owner = parts.length == 2 ? parts[0] : null;
        String name = parts[parts.length - 1];
        return maybeDownloadModel(
                modelDir, Optional.ofNullable(owner), name, true, Optional.empty(), Optional.empty(), Optional.empty());
    }

    /** Local directory for a model: {@code <modelDir>/<owner>_<name>}. */
    public static Path constructLocalModelPath(String modelDir, String owner, String modelName) {
        return Paths.get(modelDir, owner + "_" + modelName);
    }

    /**
     * Downloads a Hugging Face model's safetensors, config, tokenizer and README files unless the
     * local directory already contains the {@link #FINISHED_MARKER}.
     *
     * @param modelDir base directory for local models
     * @param modelOwner repository owner; {@code "na"} is used in the local directory name if absent
     * @param modelName repository name
     * @param downloadWeights if false, safetensors files are not fetched (but must still exist remotely)
     * @param optionalBranch revision to list and download; {@code "main"} if absent
     * @param optionalAuthHeader passed through to the HTTP helpers for authentication
     * @param optionalProgressReporter passed through to the file download helper
     * @return the local model directory
     * @throws IOException if the listing is unavailable, the model has no safetensors files, a remote
     *     file name would resolve outside the model directory, or a download fails
     */
    public static File maybeDownloadModel(
            String modelDir,
            Optional<String> modelOwner,
            String modelName,
            boolean downloadWeights,
            Optional<String> optionalBranch,
            Optional<String> optionalAuthHeader,
            Optional<TriConsumer<String, Long, Long>> optionalProgressReporter)
            throws IOException {
        Path localModelDir = constructLocalModelPath(modelDir, modelOwner.orElse("na"), modelName);
        if (Files.exists(localModelDir.resolve(FINISHED_MARKER))) {
            return localModelDir.toFile();
        }

        String hfModel = modelOwner.map(owner -> owner + "/" + modelName).orElse(modelName);
        String listingUrl =
                "https://huggingface.co/api/models/" + hfModel + "/tree/" + optionalBranch.orElse("main");
        String listing = HttpSupport.readInputStream(
                HttpSupport.getResponse(listingUrl, optionalAuthHeader, Optional.empty()).left);
        if (listing == null) {
            throw new IOException(
                    "No valid model found or trying to access a restricted model (please include correct access token)");
        }

        List<String> remoteFiles = parseFileList(listing);
        if (remoteFiles.isEmpty()) {
            throw new IOException("No valid model found");
        }

        List<String> toDownload = new ArrayList<>();
        boolean hasSafetensors = false;
        for (String file : remoteFiles) {
            String lower = file.toLowerCase(Locale.ROOT);
            boolean isSafetensor = lower.contains("safetensor");
            boolean relevant = (isSafetensor && !lower.contains("consolidated"))
                    || lower.contains("readme")
                    || lower.equals("config.json")
                    || lower.contains("tokenizer");
            if (!relevant) {
                continue;
            }
            hasSafetensors |= isSafetensor;
            if (downloadWeights || !isSafetensor) {
                toDownload.add(file);
            }
        }
        if (!hasSafetensors) {
            throw new IOException("Model is not available in safetensor format");
        }

        Files.createDirectories(localModelDir);
        for (String file : toDownload) {
            HttpSupport.downloadFile(
                    hfModel,
                    file,
                    optionalBranch,
                    optionalAuthHeader,
                    Optional.empty(),
                    resolveInside(localModelDir, file),
                    optionalProgressReporter);
        }
        Files.createFile(localModelDir.resolve(FINISHED_MARKER));
        return localModelDir.toFile();
    }

    /**
     * Resolves a remotely supplied file name under {@code dir}, refusing names (e.g. containing
     * {@code ..}) that would escape it.
     */
    private static Path resolveInside(Path dir, String fileName) throws IOException {
        Path base = dir.toAbsolutePath().normalize();
        Path target = base.resolve(fileName).normalize();
        if (!target.startsWith(base) || target.equals(base)) {
            throw new IOException("Refusing to download file outside model directory: " + fileName);
        }
        return target;
    }

    /**
     * Extracts file paths from a Hugging Face tree listing (a JSON array of objects with a
     * {@code path} field). Entries typed {@code directory} or lacking a path are skipped; a
     * non-array response yields an empty list.
     */
    private static List<String> parseFileList(String json) throws IOException {
        List<String> files = new ArrayList<>();
        JsonNode root = JsonSupport.om.readTree(json);
        if (root.isArray()) {
            for (JsonNode entry : root) {
                if ("directory".equals(entry.path("type").asText())) {
                    continue;
                }
                String path = entry.path("path").asText();
                if (!path.isEmpty()) {
                    files.add(path);
                }
            }
        }
        return files;
    }
}
