package com.github.tjake.jlama.model;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

/* loaded from: input_file:com/github/tjake/jlama/model/CausalSelfAttention.class */
/**
 * Causal multi-head self-attention for a single transformer layer.
 *
 * <p>Each {@link #forward} call does the following:
 * <ol>
 *   <li>Projects a batch of input rows to query/key/value and adds the optional biases.</li>
 *   <li>Writes each row's key and value into the caller-supplied cache at that row's sequence position.</li>
 *   <li>Rotates the query and cached key with {@code Config.ropeFreqs} when present.</li>
 *   <li>Attends each query head over cache slots {@code [0, position]}.</li>
 *   <li>Applies the output projection.</li>
 * </ol>
 *
 * <p>Work is restricted to heads {@code [c.headStart(), c.headEnd())} and to the embedding/kv segments reported by
 * {@link Config}. Presumably this lets one layer be split across several workers; TODO confirm.
 */
public class CausalSelfAttention {
    private final AbstractModel m;
    private final Config c;
    private final Optional<AbstractTensor> queryAttnBias;
    private final Optional<AbstractTensor> keyAttnBias;
    private final Optional<AbstractTensor> valueAttnBias;
    private final Optional<AbstractTensor> outputProjectionBias;
    final AbstractTensor queryAttnWeights;
    final AbstractTensor keyAttnWeights;
    final AbstractTensor valueAttnWeights;
    private final AbstractTensor outputProjectionWeights;
    /** Attention score scale, {@code 1 / sqrt(headSize)}. */
    private final float attentionScale;
    /** Query, key and value weights, in the order consumed by the fused (non-GQA) projection. */
    private final AbstractTensor[] qkvWeights;

    /**
     * Creates an attention layer without any bias terms.
     *
     * @param model owning model; supplies the {@link Config} and tensor allocation
     * @param queryAttnWeights query projection weights
     * @param keyAttnWeights key projection weights
     * @param valueAttnWeights value projection weights
     * @param outputProjectionWeights output projection weights
     */
    public CausalSelfAttention(AbstractModel model, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights,
                               AbstractTensor valueAttnWeights, AbstractTensor outputProjectionWeights) {
        this(model, Optional.empty(), Optional.empty(), Optional.empty(),
                queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.empty(), outputProjectionWeights);
    }

    /**
     * Creates an attention layer with a bias on every projection.
     *
     * @param model owning model; supplies the {@link Config} and tensor allocation
     * @param queryAttnBias bias added to the query projection
     * @param keyAttnBias bias added to the key projection
     * @param valueAttnBias bias added to the value projection
     * @param queryAttnWeights query projection weights
     * @param keyAttnWeights key projection weights
     * @param valueAttnWeights value projection weights
     * @param outputProjectionBias bias added to the output projection
     * @param outputProjectionWeights output projection weights
     */
    public CausalSelfAttention(AbstractModel model, AbstractTensor queryAttnBias, AbstractTensor keyAttnBias,
                               AbstractTensor valueAttnBias, AbstractTensor queryAttnWeights,
                               AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
                               AbstractTensor outputProjectionBias, AbstractTensor outputProjectionWeights) {
        this(model, Optional.of(queryAttnBias), Optional.of(keyAttnBias), Optional.of(valueAttnBias),
                queryAttnWeights, keyAttnWeights, valueAttnWeights, Optional.of(outputProjectionBias),
                outputProjectionWeights);
    }

    /**
     * Creates an attention layer in which each bias is independently optional.
     *
     * @param model owning model; supplies the {@link Config} and tensor allocation
     * @param queryAttnBias optional bias added to the query projection
     * @param keyAttnBias optional bias added to the key projection
     * @param valueAttnBias optional bias added to the value projection
     * @param queryAttnWeights query projection weights
     * @param keyAttnWeights key projection weights
     * @param valueAttnWeights value projection weights
     * @param outputProjectionBias optional bias added to the output projection
     * @param outputProjectionWeights output projection weights
     */
    public CausalSelfAttention(AbstractModel model, Optional<AbstractTensor> queryAttnBias,
                               Optional<AbstractTensor> keyAttnBias, Optional<AbstractTensor> valueAttnBias,
                               AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights,
                               AbstractTensor valueAttnWeights, Optional<AbstractTensor> outputProjectionBias,
                               AbstractTensor outputProjectionWeights) {
        this.m = model;
        this.c = model.c;
        this.queryAttnBias = queryAttnBias;
        this.keyAttnBias = keyAttnBias;
        this.valueAttnBias = valueAttnBias;
        this.queryAttnWeights = queryAttnWeights;
        this.keyAttnWeights = keyAttnWeights;
        this.valueAttnWeights = valueAttnWeights;
        this.outputProjectionBias = outputProjectionBias;
        this.outputProjectionWeights = outputProjectionWeights;
        this.attentionScale = (float) (1.0 / StrictMath.sqrt(this.c.headSize));
        this.qkvWeights = new AbstractTensor[] {queryAttnWeights, keyAttnWeights, valueAttnWeights};
    }

    /**
     * Runs causal self-attention over a batch of consecutive token rows.
     *
     * @param input 2-D activations with one row per token and {@code c.embeddingLength} columns
     * @param startPosition sequence position of the first row; row {@code r} is cached and attends as position
     *     {@code startPosition + r}
     * @param kvMem cache tensor. {@code kvMem.slice(true, 0)} receives key rows and {@code kvMem.slice(true, 1)}
     *     receives value rows; both are indexed by sequence position.
     * @param tensorHook if present, invoked with {@code [query, key, value]} right after their projections (before
     *     biases). It is invoked again with the output projection before the output bias is added. Presumably used
     *     to combine partial results across workers; TODO confirm.
     * @return a newly allocated {@code (rows, embeddingLength)} tensor. The caller owns it and must close it.
     * @throws IllegalArgumentException if {@code input} is not 2-D with last dimension {@code c.embeddingLength}
     */
    public AbstractTensor forward(AbstractTensor input, int startPosition, AbstractTensor kvMem,
                                  Optional<Consumer<List<AbstractTensor>>> tensorHook) {
        Preconditions.checkArgument(input.dims() == 2 && input.shape().last() == c.embeddingLength);
        int batchSize = input.shape().first();

        try (AbstractTensor query = m.makeFullTensor(batchSize, c.embeddingLength);
             AbstractTensor key = m.makeFullTensor(batchSize, c.kvLength);
             AbstractTensor value = m.makeFullTensor(batchSize, c.kvLength);
             AbstractTensor attnOutput = m.makeFullTensor(batchSize, c.embeddingLength)) {

            projectQkv(input, query, key, value);
            tensorHook.ifPresent(hook -> hook.accept(List.of(query, key, value)));

            queryAttnBias.ifPresent(bias -> {
                TensorOperationsProvider.get().accumulate(query, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength());
            });
            keyAttnBias.ifPresent(bias -> {
                TensorOperationsProvider.get().accumulate(key, bias, c.kvSegmentStart(), c.kvSegmentLength());
            });
            valueAttnBias.ifPresent(bias -> {
                TensorOperationsProvider.get().accumulate(value, bias, c.kvSegmentStart(), c.kvSegmentLength());
            });

            attendAllRows(startPosition, batchSize, query, key, value, kvMem, attnOutput);
            return projectOutput(batchSize, attnOutput, tensorHook);
        }
    }

    /** Computes the query, key and value projections of {@code input} into the given result tensors. */
    private void projectQkv(AbstractTensor input, AbstractTensor query, AbstractTensor key, AbstractTensor value) {
        if (c.isGQA) {
            // Query and key/value outputs have different widths (embeddingLength vs kvLength),
            // so they are chunked separately.
            VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkSize) -> {
                TensorOperationsProvider.get().dotProductChunk(query, input, queryAttnWeights,
                        c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkSize);
            });
            VectorMath.pchunk(0, c.kvLength, (chunkStart, chunkSize) -> {
                TensorOperationsProvider.get().dotProductChunk(key, input, keyAttnWeights,
                        c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkSize);
                TensorOperationsProvider.get().dotProductChunk(value, input, valueAttnWeights,
                        c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkSize);
            });
        } else {
            // Fused path: all three outputs are chunked over embeddingLength, which presumably assumes
            // kvLength == embeddingLength when !isGQA. TODO confirm.
            // The result array is per-call so concurrent forward() calls cannot clobber each other.
            AbstractTensor[] qkvResults = {query, key, value};
            VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkSize) -> {
                TensorOperationsProvider.get().dotProductBatchChunk(qkvResults, input, qkvWeights,
                        c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkSize);
            });
        }
    }

    /**
     * Handles each input row in order.
     *
     * <p>For every row this copies the key/value into the cache at {@code startPosition + row}, applies RoPE when
     * configured, and accumulates attention for every local head into the matching row of {@code attnOutput}.
     */
    private void attendAllRows(int startPosition, int batchSize, AbstractTensor query, AbstractTensor key,
                               AbstractTensor value, AbstractTensor kvMem, AbstractTensor attnOutput) {
        // Loop-invariant views of the key and value planes of the cache.
        AbstractTensor keyCache = kvMem.slice(true, 0);
        AbstractTensor valueCache = kvMem.slice(true, 1);

        for (int row = 0; row < batchSize; row++) {
            int position = startPosition + row;
            AbstractTensor cachedKey = keyCache.slice(position);
            AbstractTensor cachedValue = valueCache.slice(position);
            AbstractTensor keyRow = key.slice(row);
            AbstractTensor valueRow = value.slice(row);
            AbstractTensor queryRow = query.slice(row);
            AbstractTensor outputRow = attnOutput.slice(row);

            cachedKey.copyFrom(keyRow, keyRow.getOffset(0, c.kvSegmentStart()),
                    cachedKey.getOffset(0, c.kvSegmentStart()), c.kvSegmentLength());
            cachedValue.copyFrom(valueRow, valueRow.getOffset(0, c.kvSegmentStart()),
                    cachedValue.getOffset(0, c.kvSegmentStart()), c.kvSegmentLength());

            // RoPE rotates the query row and the already-cached key in place.
            c.ropeFreqs.ifPresent(freqs -> applyRope(freqs, position, queryRow, cachedKey));

            VectorMath.pfor(c.headStart(), c.headEnd(),
                    h -> attendHead(h, position, queryRow, keyCache, valueCache, outputRow));
        }
    }

    /**
     * Computes scaled dot-product attention for head {@code h} of the query at {@code position}.
     *
     * <p>The scores cover cache slots {@code [0, position]}, are scaled by {@link #attentionScale} and
     * soft-maxed, and the resulting weighted sum of cached values is accumulated into {@code outputRow}.
     * The query head reads the key/value head that {@code c.maybeMapToGroupHead(h)} selects.
     */
    private void attendHead(int h, int position, AbstractTensor queryRow, AbstractTensor keyCache,
                            AbstractTensor valueCache, AbstractTensor outputRow) {
        try (AbstractTensor scores = m.makeFullTensor(1, keyCache.shape().first())) {
            int groupOffset = c.maybeMapToGroupHead(h) * c.headSize;
            int headOffset = h * c.headSize;
            int span = position + 1; // causal: only positions 0..position are visible
            TensorOperationsProvider.get().batchDotProduct(scores, queryRow, keyCache, headOffset, groupOffset,
                    c.headSize, 0, span);
            TensorOperationsProvider.get().scale(attentionScale, scores, 0, span);
            VectorMath.softMax(scores, 0, span);
            TensorOperationsProvider.get().saxpy(scores, valueCache, outputRow, groupOffset, headOffset,
                    c.headSize, span);
        }
    }

    /**
     * Applies rotary position embeddings to one query row and one cached key row, in place.
     *
     * <p>Within each head, element {@code i} is paired with element {@code i + headSize/2}.
     *
     * <p>NOTE(review): the frequency index is {@code position * headSize/2} plus the absolute dimension offset
     * ({@code head * headSize + j}), not just {@code j}. The same offset is used for a query head and the
     * key/value head it reads, so relative rotations are unaffected. However, {@code freqs} must extend past
     * {@code position * headSize/2 + headSize/2}. Confirm how ropeFreqs is sized.
     */
    private void applyRope(float[][] freqs, int position, AbstractTensor queryRow, AbstractTensor keyRow) {
        int halfHead = c.headSize / 2;
        int posOffset = position * halfHead;

        if (!c.isGQA) {
            for (int h = c.headStart(); h < c.headEnd(); h++) {
                int headOffset = h * c.headSize;
                for (int i = headOffset; i < headOffset + halfHead; i++) {
                    float[] f = freqs[posOffset + i];
                    rotatePair(queryRow, i, halfHead, f);
                    rotatePair(keyRow, i, halfHead, f);
                }
            }
            return;
        }

        // GQA: a query head uses the frequency offset of the key/value group head it maps to.
        for (int h = c.headStart(); h < c.headEnd(); h++) {
            int headOffset = h * c.headSize;
            int groupOffset = c.maybeMapToGroupHead(h) * c.headSize;
            for (int j = 0; j < halfHead; j++) {
                rotatePair(queryRow, headOffset + j, halfHead, freqs[posOffset + groupOffset + j]);
            }
        }
        for (int g = c.groupHeadStart(); g < c.groupHeadEnd(); g++) {
            int groupOffset = g * c.headSize;
            for (int i = groupOffset; i < groupOffset + halfHead; i++) {
                rotatePair(keyRow, i, halfHead, freqs[posOffset + i]);
            }
        }
    }

    /**
     * Rotates the pair {@code (t[0,i], t[0,i+pairDistance])} in place.
     *
     * <p>The rotation is the complex multiplication {@code (x0 + i*x1) * (f[0] + i*f[1])}.
     */
    private static void rotatePair(AbstractTensor t, int i, int pairDistance, float[] f) {
        float x0 = t.get(0, i);
        float x1 = t.get(0, i + pairDistance);
        float fr = f[0];
        float fi = f[1];
        t.set(x0 * fr - x1 * fi, 0, i);
        t.set(x0 * fi + x1 * fr, 0, i + pairDistance);
    }

    /**
     * Applies the output projection (plus optional bias) to {@code attnOutput} and returns a new tensor.
     *
     * <p>The returned tensor is closed again if anything fails before it is handed to the caller.
     */
    private AbstractTensor projectOutput(int batchSize, AbstractTensor attnOutput,
                                         Optional<Consumer<List<AbstractTensor>>> tensorHook) {
        AbstractTensor result = m.makeFullTensor(batchSize, c.embeddingLength);
        try (AbstractTensor quantized = m.maybeQuantize(attnOutput)) {
            VectorMath.pchunk(0, c.embeddingLength, (chunkStart, chunkSize) -> {
                TensorOperationsProvider.get().dotProductChunk(result, quantized, outputProjectionWeights,
                        c.embeddingSegmentStart(), c.embeddingSegmentLength(), chunkStart, chunkSize);
            });
            tensorHook.ifPresent(hook -> hook.accept(Collections.singletonList(result)));
            outputProjectionBias.ifPresent(bias -> {
                TensorOperationsProvider.get().accumulate(result, bias, c.embeddingSegmentStart(), c.embeddingSegmentLength());
            });
        } catch (Throwable t) {
            // Do not leak the output tensor when the projection, hook or bias fails.
            try {
                result.close();
            } catch (Throwable suppressed) {
                t.addSuppressed(suppressed);
            }
            throw t;
        }
        return result;
    }
}
