package com.github.tjake.jlama.model;

import com.github.tjake.jlama.tensor.AbstractTensor;

/* loaded from: input_file:com/github/tjake/jlama/model/RMSNorm.class */
/**
 * Root-mean-square normalization layer.
 *
 * <p>For every row of the input, {@link #forward(AbstractTensor, int, int)} sums the squares of
 * the columns in the requested slice, divides that sum by the model's configured
 * {@code embeddingLength}, adds {@code layerNormEps}, and multiplies each element of the slice by
 * the reciprocal square root of the result and by {@code weightAdjustment + weight} for that
 * column.
 */
public class RMSNorm extends LayerNorm {
    /** Constant added to each weight before the weight is applied as the per-column scale. */
    private final float weightAdjustment;

    /**
     * Creates a norm with a weight adjustment of {@code 0.0f}, i.e. the weights are applied as-is.
     *
     * @param abstractModel model handed to the {@link LayerNorm} constructor
     * @param abstractTensor tensor handed to the {@link LayerNorm} constructor as its last
     *     argument (presumably stored as {@code weights} - confirm against LayerNorm)
     */
    public RMSNorm(AbstractModel abstractModel, AbstractTensor abstractTensor) {
        this(abstractModel, abstractTensor, 0.0f);
    }

    /**
     * Creates a norm with an explicit weight adjustment.
     *
     * @param abstractModel model handed to the {@link LayerNorm} constructor
     * @param abstractTensor tensor handed to the {@link LayerNorm} constructor as its last
     *     argument (presumably stored as {@code weights} - confirm against LayerNorm)
     * @param f value added to every weight when scaling the normalized input
     */
    public RMSNorm(AbstractModel abstractModel, AbstractTensor abstractTensor, float f) {
        // The middle superclass argument is always null here
        // (presumably the bias, which this norm does not use - confirm against LayerNorm).
        super(abstractModel, null, abstractTensor);
        weightAdjustment = f;
    }

    /**
     * Normalizes columns {@code [i, i + i2)} of every row of the input.
     *
     * <p>Only that column range is written in the returned tensor; other positions keep whatever
     * {@code copyShape()} produced.
     *
     * @param abstractTensor input tensor, read as {@code (row, column)}
     * @param i first column of the slice to normalize
     * @param i2 number of columns in the slice
     * @return the tensor obtained from {@code abstractTensor.copyShape()}, filled in for the slice
     */
    @Override
    public AbstractTensor forward(AbstractTensor abstractTensor, int i, int i2) {
        final int rowCount = abstractTensor.shape().first();
        final AbstractTensor normalized = abstractTensor.copyShape();
        final int endColumn = i + i2;

        for (int row = 0; row < rowCount; row++) {
            final double squareSum = sumOfSquares(abstractTensor, row, i, endColumn);

            // The mean is taken over the model-wide embedding length, not over the slice length.
            final double meanSquare = squareSum / this.m.c.embeddingLength;
            final double inverseRms = 1.0d / StrictMath.sqrt(meanSquare + this.m.c.layerNormEps);

            // Narrowed once per row; the per-element math below is done in float.
            final float rowScale = (float) inverseRms;

            for (int col = i; col < endColumn; col++) {
                final float gain = this.weightAdjustment + this.weights.get(0, col);
                final float scaledGain = gain * rowScale;
                normalized.set(scaledGain * abstractTensor.get(row, col), row, col);
            }
        }

        return normalized;
    }

    /** Sums the squares of {@code tensor[row][from..to)}; each square is computed in float. */
    private static double sumOfSquares(AbstractTensor tensor, int row, int from, int to) {
        double total = 0.0d;
        for (int col = from; col < to; col++) {
            final float value = tensor.get(row, col);
            total += value * value;
        }
        return total;
    }
}
