package com.github.tjake.jlama.model.gpt2;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/model/gpt2/GPT2Model.class */
/**
 * GPT-2 model: maps the GPT-2 checkpoint tensor names ({@code wte.weight}, {@code wpe.weight},
 * {@code h.<layer>.*}, {@code ln_f.*}) onto the generic {@link AbstractModel} building blocks.
 *
 * <p>The output logits reuse the {@code wte.weight} tensor that is also used for the input
 * embedding (see {@link #loadOutputWeights()}).
 */
public class GPT2Model extends AbstractModel {

    /**
     * Creates a GPT-2 model using {@link AbstractModel.InferenceType#FULL_GENERATION}.
     *
     * @param config model configuration
     * @param weightLoader source of the checkpoint tensors
     * @param tokenizer tokenizer for this model
     * @param dType forwarded unchanged to the {@link AbstractModel} constructor
     * @param dType2 forwarded unchanged to the {@link AbstractModel} constructor
     * @param optional forwarded unchanged to the {@link AbstractModel} constructor
     */
    public GPT2Model(
            Config config,
            WeightLoader weightLoader,
            Tokenizer tokenizer,
            DType dType,
            DType dType2,
            Optional<DType> optional) {
        this(AbstractModel.InferenceType.FULL_GENERATION, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Creates a GPT-2 model with an explicit inference type.
     *
     * @param inferenceType which parts of the model to run
     * @param config model configuration
     * @param weightLoader source of the checkpoint tensors
     * @param tokenizer tokenizer for this model
     * @param dType forwarded unchanged to the {@link AbstractModel} constructor
     * @param dType2 forwarded unchanged to the {@link AbstractModel} constructor
     * @param optional forwarded unchanged to the {@link AbstractModel} constructor
     */
    public GPT2Model(
            AbstractModel.InferenceType inferenceType,
            Config config,
            WeightLoader weightLoader,
            Tokenizer tokenizer,
            DType dType,
            DType dType2,
            Optional<DType> optional) {
        super(inferenceType, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.GPT2;
    }

    /**
     * Builds the input embedding function.
     *
     * <p>For a given token and position, the returned function builds a {@code 1 x embeddingLength}
     * dense tensor. Column {@code i} holds {@code wte[inputToken, i] + wpe[position, i]}.
     *
     * @return embedding function backed by the {@code wte.weight} and {@code wpe.weight} tensors
     */
    @Override
    protected EmbedInput loadInputWeights() {
        final AbstractTensor wte = this.weights.load("wte.weight");
        final AbstractTensor wpe = this.weights.load("wpe.weight");

        // The lambda parameters must not share a name with the loop index: redeclaring a lambda
        // parameter inside its body is a compile error, and it hides which row is being read.
        return (inputToken, position) -> {
            AbstractTensor embedding = makeDenseTensor(1, this.c.embeddingLength);
            for (int col = 0; col < this.c.embeddingLength; col++) {
                embedding.set(wte.get(inputToken, col) + wpe.get(position, col), 0, col);
            }
            return embedding;
        };
    }

    /**
     * Loads the transformer layers owned by this instance.
     *
     * <p>The array is sized to {@code dctx().numberOfLayers}. Only indices in
     * {@code [dctx().layerStart, dctx().layerEnd)} are populated; other slots are left {@code null}.
     *
     * @return per-layer transformer blocks, indexed by absolute layer number
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        TransformerBlock[] blocks = new TransformerBlock[this.c.dctx().numberOfLayers];
        for (int layer = this.c.dctx().layerStart; layer < this.c.dctx().layerEnd; layer++) {
            String prefix = "h." + layer + ".";

            // Load order (attention, MLP, ln_1, ln_2) matches the original implementation.
            CausalSelfAttention attention = buildAttention(layer, prefix + "attn.");
            MLPBlock mlp = buildMlp(prefix + "mlp.");
            LayerNorm preAttentionNorm = buildLayerNorm(prefix + "ln_1.");
            LayerNorm preMlpNorm = buildLayerNorm(prefix + "ln_2.");

            blocks[layer] = new TransformerBlock(this, layer, preAttentionNorm, attention, preMlpNorm, mlp);
        }
        return blocks;
    }

    /**
     * Builds the output sampler.
     *
     * <p>It uses {@code ln_f} as the final layer norm and reuses {@code wte.weight} as the
     * output-logits weights.
     *
     * @return output sampler for this model
     */
    @Override
    protected SampleOutput loadOutputWeights() {
        final AbstractTensor tiedEmbedding = this.weights.load("wte.weight");
        final LayerNorm finalNorm = buildLayerNorm("ln_f.");
        return new SampleOutput(this) {
            @Override
            public LayerNorm getOutputLayerNorm() {
                return finalNorm;
            }

            @Override
            public AbstractTensor getOutputLogitsWeights() {
                return tiedEmbedding;
            }
        };
    }

    /** Builds a LayerNorm from the {@code <prefix>bias} and {@code <prefix>weight} tensors (loaded in that order). */
    private LayerNorm buildLayerNorm(String prefix) {
        return new LayerNorm(this, this.weights.load(prefix + "bias"), this.weights.load(prefix + "weight"));
    }

    /**
     * Builds one layer's causal self-attention from the fused {@code c_attn} tensors and the
     * {@code c_proj} output projection under {@code attnPrefix}.
     */
    private CausalSelfAttention buildAttention(int layer, String attnPrefix) {
        // The fused c_attn tensors are split into three equal parts. The bias is split along dim 1;
        // the transposed weight is split along dim 0.
        // NOTE(review): the parts are passed in checkpoint order, presumably query/key/value — confirm.
        AbstractTensor[] fusedBias = this.weights.load(attnPrefix + "c_attn.bias").split(3, 1);
        // NOTE(review): the transposes presumably convert GPT-2's Conv1D (in, out) weight layout
        // into the layout the attention/MLP blocks expect — confirm.
        AbstractTensor[] fusedWeight = this.weights.load(attnPrefix + "c_attn.weight").transpose().split(3, 0);
        AbstractTensor outputBias = this.weights.load(attnPrefix + "c_proj.bias");
        AbstractTensor outputWeight = this.weights.load(attnPrefix + "c_proj.weight").transpose();
        return new CausalSelfAttention(
                this,
                layer,
                fusedBias[0],
                fusedBias[1],
                fusedBias[2],
                fusedWeight[0],
                fusedWeight[1],
                fusedWeight[2],
                outputBias,
                outputWeight);
    }

    /** Builds one layer's MLP from the {@code c_fc} and {@code c_proj} tensors under {@code mlpPrefix}. */
    private MLPBlock buildMlp(String mlpPrefix) {
        AbstractTensor fcBias = this.weights.load(mlpPrefix + "c_fc.bias");
        AbstractTensor fcWeight = this.weights.load(mlpPrefix + "c_fc.weight").transpose();
        AbstractTensor projBias = this.weights.load(mlpPrefix + "c_proj.bias");
        AbstractTensor projWeight = this.weights.load(mlpPrefix + "c_proj.weight").transpose();
        return new MLPBlock(this, this.c.activationFunction, fcBias, fcWeight, projBias, projWeight);
    }
}
