package com.github.tjake.jlama.model.bert;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.PoolingLayer;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/model/bert/BertModel.class */
/**
 * {@link AbstractModel} implementation reporting {@link ModelSupport.ModelType#BERT}.
 *
 * <p>Weights are looked up by name through {@link #loadWeight(String)}, which tries each entry of
 * {@link #WEIGHT_PREFIXES} in order, so checkpoints with and without the {@code "bert."} name
 * prefix are both accepted.
 */
public class BertModel extends AbstractModel {
    /** Name prefixes tried, in this order, when resolving a weight name. */
    private static final String[] WEIGHT_PREFIXES = {"", "bert."};

    /**
     * Creates a model that uses {@link AbstractModel.InferenceType#FORWARD_PASS}.
     *
     * <p>All arguments are forwarded unchanged to the {@link AbstractModel} constructor.
     */
    public BertModel(Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(AbstractModel.InferenceType.FORWARD_PASS, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Creates a model with an explicit inference type.
     *
     * <p>All arguments are forwarded unchanged to the {@link AbstractModel} constructor.
     */
    public BertModel(AbstractModel.InferenceType inferenceType, Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(inferenceType, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Loads the weight tensor called {@code str}, trying every known name prefix in order.
     *
     * @param str weight name without any prefix, e.g. {@code "pooler.dense.weight"}
     * @return the tensor stored under the first prefixed name that is present
     * @throws NoSuchElementException if the name is absent under every prefix
     */
    protected AbstractTensor loadWeight(String str) {
        for (String prefix : WEIGHT_PREFIXES) {
            String qualifiedName = prefix + str;
            if (this.weights.isWeightPresent(qualifiedName)) {
                return this.weights.load(qualifiedName);
            }
        }
        throw new NoSuchElementException(Arrays.toString(WEIGHT_PREFIXES) + " " + str + " not found in weights");
    }

    /** Returns {@link ModelSupport.ModelType#BERT}. */
    @Override // com.github.tjake.jlama.model.AbstractModel
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.BERT;
    }

    /**
     * Builds the input embedding function.
     *
     * <p>For each embedding dimension the returned function sums the word embedding of the token,
     * row 0 of the token-type embeddings and the position embedding, then applies the embeddings
     * layer norm to the result.
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected EmbedInput loadInputWeights() {
        AbstractTensor wordEmbeddings = loadWeight("embeddings.word_embeddings.weight");
        AbstractTensor tokenTypeEmbeddings = loadWeight("embeddings.token_type_embeddings.weight");
        AbstractTensor positionEmbeddings = loadWeight("embeddings.position_embeddings.weight");
        LayerNorm embeddingNorm = new LayerNorm(this, loadWeight("embeddings.LayerNorm.bias"), loadWeight("embeddings.LayerNorm.weight"));
        // The lambda parameters must not share a name with the loop index below: Java forbids
        // redeclaring a lambda parameter as a local, and the lookups need all three indices.
        return (inputToken, position) -> {
            AbstractTensor embedding = makeDenseTensor(this.c.embeddingLength);
            try {
                for (int i = 0; i < this.c.embeddingLength; i++) {
                    // Only row 0 of the token-type table is ever read here.
                    embedding.set(wordEmbeddings.get(inputToken, i) + tokenTypeEmbeddings.get(0, i) + positionEmbeddings.get(position, i), 0, i);
                }
                return embeddingNorm.forward(embedding);
            } finally {
                // Release the scratch tensor even when the layer norm throws.
                embedding.close();
            }
        };
    }

    /**
     * Builds the transformer blocks for layers {@code layerStart} (inclusive) to {@code layerEnd}
     * (exclusive) of the distributed context, each stored at its absolute layer index.
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected TransformerBlock[] loadTransformerBlockWeights() {
        // NOTE(review): the array is sized by embeddingSegmentLength but indexed by layer number
        // below - confirm this is always at least layerEnd.
        TransformerBlock[] blocks = new TransformerBlock[this.c.dctx().embeddingSegmentLength];
        for (int layer = this.c.dctx().layerStart; layer < this.c.dctx().layerEnd; layer++) {
            String layerPrefix = "encoder.layer." + layer + ".";
            String attentionPrefix = layerPrefix + "attention.";

            // Weights are loaded in the same order as before; only the names are new.
            AbstractTensor keyBias = loadWeight(attentionPrefix + "self.key.bias");
            AbstractTensor keyWeight = loadWeight(attentionPrefix + "self.key.weight");
            AbstractTensor queryBias = loadWeight(attentionPrefix + "self.query.bias");
            AbstractTensor valueBias = loadWeight(attentionPrefix + "self.value.bias");
            AbstractTensor queryWeight = loadWeight(attentionPrefix + "self.query.weight");
            AbstractTensor valueWeight = loadWeight(attentionPrefix + "self.value.weight");
            AbstractTensor attentionOutputBias = loadWeight(attentionPrefix + "output.dense.bias");
            AbstractTensor attentionOutputWeight = loadWeight(attentionPrefix + "output.dense.weight");
            CausalSelfAttention attention = new CausalSelfAttention(this, layer, keyBias, queryBias, valueBias, keyWeight, queryWeight, valueWeight, attentionOutputBias, attentionOutputWeight);

            MLPBlock mlp = new MLPBlock(this, this.c.activationFunction, loadWeight(layerPrefix + "intermediate.dense.bias"), loadWeight(layerPrefix + "intermediate.dense.weight"), loadWeight(layerPrefix + "output.dense.bias"), loadWeight(layerPrefix + "output.dense.weight"));

            LayerNorm attentionNorm = new LayerNorm(this, loadWeight(layerPrefix + "attention.output.LayerNorm.bias"), loadWeight(layerPrefix + "attention.output.LayerNorm.weight"));
            LayerNorm outputNorm = new LayerNorm(this, loadWeight(layerPrefix + "output.LayerNorm.bias"), loadWeight(layerPrefix + "output.LayerNorm.weight"));

            blocks[layer] = new TransformerBlock(this, layer, attention, attentionNorm, mlp, outputNorm);
        }
        return blocks;
    }

    /**
     * Not supported by this model.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected SampleOutput loadOutputWeights() {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds the pooling layer from the {@code pooler.dense} weight and bias tensors.
     *
     * @throws NoSuchElementException if either pooler tensor is missing
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected PoolingLayer loadPoolingWeights() {
        final AbstractTensor poolerWeight = loadWeight("pooler.dense.weight");
        final AbstractTensor poolerBias = loadWeight("pooler.dense.bias");
        // NOTE(review): the `this` constructor argument may be a decompiler artifact - confirm
        // against the PoolingLayer declaration.
        return new PoolingLayer(this) { // from class: com.github.tjake.jlama.model.bert.BertModel.1
            @Override // com.github.tjake.jlama.model.functions.PoolingLayer
            public AbstractTensor getPoolingWeights() {
                return poolerWeight;
            }

            @Override // com.github.tjake.jlama.model.functions.PoolingLayer
            public Optional<AbstractTensor> getPoolingBias() {
                return Optional.of(poolerBias);
            }
        };
    }

    /**
     * Builds the classification head from the {@code classifier} weight and bias tensors.
     *
     * @throws UnsupportedOperationException if the configuration is not a classifier
     * @throws NoSuchElementException if either classifier tensor is missing
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected ClassifyOutput loadClassifierWeights() {
        if (!this.c.isClassifier()) {
            throw new UnsupportedOperationException("Classification not supported by this model");
        }
        final AbstractTensor classifierWeight = loadWeight("classifier.weight");
        final AbstractTensor classifierBias = loadWeight("classifier.bias");
        // NOTE(review): the `this` constructor argument may be a decompiler artifact - confirm
        // against the ClassifyOutput declaration.
        return new ClassifyOutput(this) { // from class: com.github.tjake.jlama.model.bert.BertModel.2
            @Override // com.github.tjake.jlama.model.functions.ClassifyOutput
            public AbstractTensor getClassificationWeights() {
                return classifierWeight;
            }

            @Override // com.github.tjake.jlama.model.functions.ClassifyOutput
            public Optional<AbstractTensor> getClassificationBias() {
                return Optional.of(classifierBias);
            }
        };
    }
}
