package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.bert.BertConfig;
import com.github.tjake.jlama.model.bert.BertModel;
import com.github.tjake.jlama.model.bert.BertTokenizer;
import com.github.tjake.jlama.model.gemma.GemmaConfig;
import com.github.tjake.jlama.model.gemma.GemmaModel;
import com.github.tjake.jlama.model.gemma.GemmaTokenizer;
import com.github.tjake.jlama.model.gpt2.GPT2Config;
import com.github.tjake.jlama.model.gpt2.GPT2Model;
import com.github.tjake.jlama.model.gpt2.GPT2Tokenizer;
import com.github.tjake.jlama.model.llama.LlamaConfig;
import com.github.tjake.jlama.model.llama.LlamaModel;
import com.github.tjake.jlama.model.llama.LlamaTokenizer;
import com.github.tjake.jlama.model.mistral.MistralConfig;
import com.github.tjake.jlama.model.mistral.MistralModel;
import com.github.tjake.jlama.model.mixtral.MixtralConfig;
import com.github.tjake.jlama.model.mixtral.MixtralModel;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.Pair;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry points for loading a model from disk: the architecture is detected from the model
 * directory's {@code config.json} and the matching {@link AbstractModel}, {@link Config} and
 * {@link Tokenizer} implementations are instantiated reflectively.
 */
public class ModelSupport {
    private static final Logger logger = LoggerFactory.getLogger(ModelSupport.class);

    /** Name of the configuration file that must be present in the model directory. */
    private static final String CONFIG_FILE_NAME = "config.json";

    /**
     * Supported model architectures, each bound to the model, config and tokenizer classes used
     * to load it.
     */
    public enum ModelType {
        GEMMA(GemmaModel.class, GemmaConfig.class, GemmaTokenizer.class),
        MISTRAL(MistralModel.class, MistralConfig.class, LlamaTokenizer.class),
        MIXTRAL(MixtralModel.class, MixtralConfig.class, LlamaTokenizer.class),
        LLAMA(LlamaModel.class, LlamaConfig.class, LlamaTokenizer.class),
        GPT2(GPT2Model.class, GPT2Config.class, GPT2Tokenizer.class),
        BERT(BertModel.class, BertConfig.class, BertTokenizer.class);

        /** Model implementation; must expose the 7-argument constructor used by {@code loadModel}. */
        public final Class<? extends AbstractModel> modelClass;
        /** Config type that {@code config.json} is deserialized into. */
        public final Class<? extends Config> configClass;
        /** Tokenizer implementation; must expose a {@code (Path)} constructor. */
        public final Class<? extends Tokenizer> tokenizerClass;

        ModelType(
                Class<? extends AbstractModel> modelClass,
                Class<? extends Config> configClass,
                Class<? extends Tokenizer> tokenizerClass) {
            this.modelClass = modelClass;
            this.configClass = configClass;
            this.tokenizerClass = tokenizerClass;
        }
    }

    /**
     * Loads a model for full generation with no working directory and default thread/offset
     * settings.
     *
     * @param model the model directory, or any file inside it
     * @param workingMemoryType forwarded unchanged to the model constructor
     *     (NOTE(review): presumably the working-memory dtype — confirm)
     * @param workingQuantizationType forwarded unchanged to the model constructor
     *     (NOTE(review): presumably the working quantization dtype — confirm)
     * @return the instantiated model
     * @throws IllegalArgumentException if the model location, its directory, or
     *     {@code config.json} does not exist
     * @throws RuntimeException wrapping any I/O or reflection failure during loading
     */
    public static AbstractModel loadModel(File model, DType workingMemoryType, DType workingQuantizationType) {
        return loadModel(model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty());
    }

    /**
     * Loads a model for {@link AbstractModel.InferenceType#FULL_GENERATION} with no offset.
     *
     * @param model the model directory, or any file inside it
     * @param workingDirectory passed to {@link Config#setWorkingDirectory}; may be {@code null}
     * @param workingMemoryType forwarded unchanged to the model constructor
     * @param workingQuantizationType forwarded unchanged to the model constructor
     * @param modelQuantization forwarded unchanged to the model constructor
     * @param threadCount if present, passed to {@link PhysicalCoreExecutor#overrideThreadCount}
     * @return the instantiated model
     * @throws IllegalArgumentException if the model location, its directory, or
     *     {@code config.json} does not exist
     * @throws RuntimeException wrapping any I/O or reflection failure during loading
     */
    public static AbstractModel loadModel(
            File model,
            File workingDirectory,
            DType workingMemoryType,
            DType workingQuantizationType,
            Optional<DType> modelQuantization,
            Optional<Integer> threadCount) {
        return loadModel(
                AbstractModel.InferenceType.FULL_GENERATION,
                model,
                workingDirectory,
                workingMemoryType,
                workingQuantizationType,
                modelQuantization,
                threadCount,
                Optional.empty());
    }

    /**
     * Detects the model architecture from {@code config.json}, deserializes the config, loads
     * the weights and tokenizer from the model directory, and reflectively constructs the model.
     *
     * @param inferenceType forwarded to the model constructor
     * @param model the model directory, or any file inside it (relative paths are allowed)
     * @param workingDirectory passed to {@link Config#setWorkingDirectory}; may be {@code null}
     * @param workingMemoryType forwarded unchanged to the model constructor
     * @param workingQuantizationType forwarded unchanged to the model constructor
     * @param modelQuantization forwarded unchanged to the model constructor
     * @param threadCount if present, passed to {@link PhysicalCoreExecutor#overrideThreadCount}
     *     before anything is loaded
     * @param offset if present, passed to {@link Config#setOffset}
     * @return the instantiated model
     * @throws IllegalArgumentException if the model location, its directory, or
     *     {@code config.json} does not exist
     * @throws RuntimeException wrapping any I/O or reflection failure during loading
     */
    public static AbstractModel loadModel(
            AbstractModel.InferenceType inferenceType,
            File model,
            File workingDirectory,
            DType workingMemoryType,
            DType workingQuantizationType,
            Optional<DType> modelQuantization,
            Optional<Integer> threadCount,
            Optional<Pair<Integer, Integer>> offset) {
        File modelDir = resolveModelDirectory(model);
        File configFile = findConfigFile(modelDir);

        threadCount.ifPresent(count -> PhysicalCoreExecutor.overrideThreadCount(count));

        try {
            ModelType modelType = SafeTensorSupport.detectModel(configFile);
            logger.debug("Detected model type {} from {}", modelType, configFile);

            Config config = JsonSupport.om.readValue(configFile, modelType.configClass);
            Objects.requireNonNull(config, "config.json deserialized to null: " + configFile);
            offset.ifPresent(config::setOffset);
            config.setWorkingDirectory(workingDirectory);

            // Kept as one expression so the evaluation order matches the original: the model
            // constructor is looked up first, then weights are loaded, then the tokenizer is built.
            return modelType.modelClass
                    .getConstructor(
                            AbstractModel.InferenceType.class,
                            Config.class,
                            WeightLoader.class,
                            Tokenizer.class,
                            DType.class,
                            DType.class,
                            Optional.class)
                    .newInstance(
                            inferenceType,
                            config,
                            SafeTensorSupport.loadWeights(modelDir),
                            modelType.tokenizerClass.getConstructor(Path.class).newInstance(modelDir.toPath()),
                            workingMemoryType,
                            workingQuantizationType,
                            modelQuantization);
        } catch (IOException
                | IllegalAccessException
                | InstantiationException
                | NoSuchMethodException
                | InvocationTargetException e) {
            throw new RuntimeException("Failed to load model from " + modelDir, e);
        }
    }

    /** Validates the model location and returns the directory that holds the model files. */
    private static File resolveModelDirectory(File model) {
        if (!model.exists()) {
            throw new IllegalArgumentException("Model location does not exist: " + model);
        }
        // getAbsoluteFile(): a bare relative name like "model.safetensors" has a null getParentFile().
        File modelDir = model.isFile() ? model.getAbsoluteFile().getParentFile() : model;
        if (!modelDir.isDirectory()) {
            throw new IllegalArgumentException("Model directory does not exist: " + modelDir);
        }
        return modelDir;
    }

    /** Returns the model directory's config.json, failing fast if it is missing or not a regular file. */
    private static File findConfigFile(File modelDir) {
        File configFile = new File(modelDir, CONFIG_FILE_NAME);
        if (!configFile.isFile()) {
            throw new IllegalArgumentException("config.json in model directory does not exist: " + modelDir);
        }
        return configFile;
    }
}
