package com.github.tjake.jlama.model.bert;

import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.util.Arrays;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/model/bert/BertModel.class */
/**
 * BERT-style encoder model.
 *
 * <p>Input embeddings are the sum of word, token-type and position embeddings followed by a
 * LayerNorm; each encoder layer is assembled from {@code encoder.layer.N.*} weights. This model
 * has no output/sampling head ({@link #loadOutputWeights()} always throws); use
 * {@link #embed(String)} to obtain a sentence embedding.
 */
public class BertModel extends AbstractModel {
    /**
     * Creates a BERT model that uses {@link AbstractModel.InferenceType#FORWARD_PASS}.
     */
    public BertModel(Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(AbstractModel.InferenceType.FORWARD_PASS, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Creates a BERT model with an explicit inference type.
     */
    public BertModel(AbstractModel.InferenceType inferenceType, Config config, WeightLoader weightLoader, Tokenizer tokenizer, DType dType, DType dType2, Optional<DType> optional) {
        super(inferenceType, config, weightLoader, tokenizer, dType, dType2, optional);
    }

    /**
     * Builds the input-embedding function.
     *
     * <p>For a token id and position, the returned function sums the word embedding row for the
     * token, token-type embedding row 0 (token type 0 is always used) and the position embedding
     * row for the position, then applies the embedding LayerNorm.
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected EmbedInput loadInputWeights() {
        final AbstractTensor wordEmbeddings = this.weights.load("embeddings.word_embeddings.weight");
        final AbstractTensor tokenTypeEmbeddings = this.weights.load("embeddings.token_type_embeddings.weight");
        final AbstractTensor positionEmbeddings = this.weights.load("embeddings.position_embeddings.weight");
        final LayerNorm embeddingNorm = new LayerNorm(this, this.weights.load("embeddings.LayerNorm.bias"), this.weights.load("embeddings.LayerNorm.weight"));

        return (inputToken, position) -> {
            // The summed embedding is only a scratch buffer: LayerNorm.forward returns the tensor
            // handed back to the caller, so release the buffer even if normalization fails.
            try (AbstractTensor embedding = makeTensor(this.c.embeddingLength)) {
                for (int d = 0; d < this.c.embeddingLength; d++) {
                    float value = wordEmbeddings.get(inputToken, d)
                            + tokenTypeEmbeddings.get(0, d)
                            + positionEmbeddings.get(position, d);
                    embedding.set(value, 0, d);
                }
                return embeddingNorm.forward(embedding);
            }
        };
    }

    /**
     * Loads the encoder layers in {@code [c.layerStart(), c.layerEnd())}.
     *
     * <p>The returned array has {@code c.getNumberOfLayers()} slots and is indexed by absolute layer
     * number; slots outside the loaded range are left {@code null}.
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected TransformerBlock[] loadTransformerBlockWeights() {
        TransformerBlock[] blocks = new TransformerBlock[this.c.getNumberOfLayers()];
        for (int layer = this.c.layerStart(); layer < this.c.layerEnd(); layer++) {
            String prefix = "encoder.layer." + layer + ".";
            String attnPrefix = prefix + "attention.";

            // Weights are loaded in the same order as before to keep loader access patterns unchanged.
            AbstractTensor keyBias = this.weights.load(attnPrefix + "self.key.bias");
            AbstractTensor keyWeight = this.weights.load(attnPrefix + "self.key.weight");
            AbstractTensor queryBias = this.weights.load(attnPrefix + "self.query.bias");
            AbstractTensor valueBias = this.weights.load(attnPrefix + "self.value.bias");
            AbstractTensor queryWeight = this.weights.load(attnPrefix + "self.query.weight");
            AbstractTensor valueWeight = this.weights.load(attnPrefix + "self.value.weight");
            AbstractTensor attnOutBias = this.weights.load(attnPrefix + "output.dense.bias");
            AbstractTensor attnOutWeight = this.weights.load(attnPrefix + "output.dense.weight");
            CausalSelfAttention attention = new CausalSelfAttention(this, keyBias, queryBias, valueBias, keyWeight, queryWeight, valueWeight, attnOutBias, attnOutWeight);

            LayerNorm attentionOutputNorm = new LayerNorm(this, this.weights.load(prefix + "attention.output.LayerNorm.bias"), this.weights.load(prefix + "attention.output.LayerNorm.weight"));

            MLPBlock mlp = new MLPBlock(this, ActivationFunction.Type.GELU,
                    this.weights.load(prefix + "intermediate.dense.bias"),
                    this.weights.load(prefix + "intermediate.dense.weight"),
                    this.weights.load(prefix + "output.dense.bias"),
                    this.weights.load(prefix + "output.dense.weight"));

            LayerNorm outputNorm = new LayerNorm(this, this.weights.load(prefix + "output.LayerNorm.bias"), this.weights.load(prefix + "output.LayerNorm.weight"));

            blocks[layer] = new TransformerBlock(this, attention, attentionOutputNorm, mlp, outputNorm);
        }
        return blocks;
    }

    /**
     * Not supported: this model has no output head.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // com.github.tjake.jlama.model.AbstractModel
    protected SampleOutput loadOutputWeights() {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes a sentence embedding for {@code str}.
     *
     * <p>The text is tokenized and run through {@code batchForward}. The per-token outputs are
     * averaged, and the average is L2-normalized.
     *
     * @param str the text to embed
     * @return an L2-normalized vector of length {@code c.embeddingLength}
     * @throws IllegalArgumentException if tokenization yields no tokens or at least
     *     {@code c.contextLength} tokens
     */
    public float[] embed(String str) {
        int[] tokenIds = Arrays.stream(this.tokenizer.encode(str)).mapToInt(Ints::checkedCast).toArray();
        // An empty token list would divide by zero below and produce a NaN vector.
        Preconditions.checkArgument(tokenIds.length > 0, "Input produced no tokens");
        Preconditions.checkArgument(tokenIds.length < this.c.contextLength,
                "Input of %s tokens must be shorter than the context length %s", tokenIds.length, this.c.contextLength);

        float[] pooled = new float[this.c.embeddingLength];
        float weight = 1.0f / tokenIds.length;

        // The (layers, 2, tokens, embeddingLength) tensor is the working cache passed to
        // batchForward (presumably per-layer key/value state - TODO confirm). Both it and the
        // forward output are released even if the forward pass throws.
        try (AbstractTensor cache = makeTensor(this.c.getNumberOfLayers(), 2, tokenIds.length, this.c.embeddingLength);
             AbstractTensor hidden = batchForward(tokenIds, 0, cache)) {
            for (int t = 0; t < tokenIds.length; t++) {
                AbstractTensor row = hidden.slice(t);
                for (int d = 0; d < this.c.embeddingLength; d++) {
                    pooled[d] += row.get(0, d) * weight;
                }
            }
        }

        VectorMath.l2normalize(pooled);
        return pooled;
    }
}
