package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.Pair;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/model/TransformerBlock.class */
/**
 * One transformer layer: causal self-attention followed by a feed-forward block, each followed by a
 * residual accumulation, with layer norms placed according to the constructor used.
 *
 * <p>Two layouts are supported:
 *
 * <ul>
 *   <li>pre-attention norm → attention → residual → post-attention norm → feed-forward → residual
 *       (no output norm);
 *   <li>attention → residual → post-attention norm → feed-forward → residual → post-feed-forward
 *       norm (no input norm).
 * </ul>
 */
public class TransformerBlock {
    private static final Logger logger = LoggerFactory.getLogger(TransformerBlock.class);
    private final AbstractModel model;
    final int layerIndex;
    final Optional<LayerNorm> preAttentionNorm;
    final CausalSelfAttention attention;
    final LayerNorm postAttentionNorm;
    final FeedForward ffBlock;
    final Optional<LayerNorm> postFFNorm;

    /**
     * Creates a block that normalizes its input before attention and applies no norm to its output.
     *
     * @param model owning model; supplies quantization and the embedding-segment bounds
     * @param layerIndex index of this layer, used to tag debug output
     * @param preAttentionNorm norm applied to the block input before attention; must not be null
     * @param attention the self-attention sub-layer
     * @param postAttentionNorm norm applied to the attention residual before the feed-forward block
     * @param ffBlock the feed-forward sub-layer
     * @throws NullPointerException if {@code preAttentionNorm} is null (from {@code Optional.of})
     */
    public TransformerBlock(AbstractModel model, int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock) {
        this.model = model;
        this.layerIndex = layerIndex;
        this.preAttentionNorm = Optional.of(preAttentionNorm);
        this.attention = attention;
        this.postAttentionNorm = postAttentionNorm;
        this.ffBlock = ffBlock;
        this.postFFNorm = Optional.empty();
    }

    /**
     * Creates a block that feeds its input straight into attention and normalizes its output after
     * the feed-forward residual.
     *
     * @param model owning model; supplies quantization and the embedding-segment bounds
     * @param layerIndex index of this layer, used to tag debug output
     * @param attention the self-attention sub-layer
     * @param postAttentionNorm norm applied to the attention residual before the feed-forward block
     * @param ffBlock the feed-forward sub-layer
     * @param postFFNorm norm applied to the feed-forward residual to produce the output; must not be
     *     null
     * @throws NullPointerException if {@code postFFNorm} is null (from {@code Optional.of})
     */
    public TransformerBlock(AbstractModel model, int layerIndex, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock, LayerNorm postFFNorm) {
        this.model = model;
        this.layerIndex = layerIndex;
        this.preAttentionNorm = Optional.empty();
        this.attention = attention;
        this.postAttentionNorm = postAttentionNorm;
        this.ffBlock = ffBlock;
        this.postFFNorm = Optional.of(postFFNorm);
    }

    /**
     * Runs the block with no norm reducer and no tensor reducer.
     *
     * @see #forward(AbstractTensor, int, AbstractTensor, Optional, Optional)
     */
    public AbstractTensor forward(AbstractTensor embedding, int position, AbstractTensor kvBuffer) {
        return forward(embedding, position, kvBuffer, Optional.empty(), Optional.empty());
    }

    /**
     * Runs the block on {@code embedding} and returns a new output tensor.
     *
     * <p>{@code embedding} itself is never closed here. Every intermediate tensor this method
     * creates is closed before returning. The returned tensor is owned by the caller.
     *
     * @param embedding block input; read, and accumulated into the attention output as the first
     *     residual
     * @param position passed through to the attention sub-layer (presumably the token position —
     *     TODO confirm)
     * @param kvBuffer passed through to the attention sub-layer (presumably the KV cache — TODO
     *     confirm)
     * @param normReducer optional reducer handed to every {@code LayerNorm.forward} call
     * @param tensorReducer optional reducer handed to the attention and feed-forward sub-layers
     * @return the feed-forward residual, passed through the post-feed-forward norm when one is
     *     configured
     */
    public AbstractTensor forward(
            AbstractTensor embedding,
            int position,
            AbstractTensor kvBuffer,
            Optional<BiFunction<Float, Float, Pair<Float, Float>>> normReducer,
            Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        DebugSupport.debug("input_emb", embedding, this.layerIndex);

        AbstractTensor lnemb = this.preAttentionNorm.map(ln -> ln.forward(embedding, normReducer)).orElse(embedding);
        DebugSupport.debug("ln_emb", lnemb, this.layerIndex);

        // The quantized copy is only needed for the attention call. try-with-resources releases it
        // even if attention throws; the old hand-rolled close() was skipped on that path.
        // NOTE(review): assumes maybeQuantize returns a tensor distinct from its argument; if it can
        // return lnemb itself, lnemb would be closed twice below — confirm.
        final AbstractTensor postAttention;
        try (AbstractTensor qlnemb = this.model.maybeQuantize(lnemb)) {
            postAttention = this.attention.forward(qlnemb, position, kvBuffer, tensorReducer);
        }
        DebugSupport.debug("post_attn", postAttention, this.layerIndex);

        // Residual connection around attention (assumes accumulate adds the second tensor into the
        // first, in place, over the embedding segment — TODO confirm).
        TensorOperationsProvider.get().accumulate(
                postAttention, embedding, this.model.c.embeddingSegmentStart(), this.model.c.embeddingSegmentLength());
        DebugSupport.debug("post_attn_res", postAttention, this.layerIndex);

        AbstractTensor lnattn = this.postAttentionNorm.forward(postAttention, normReducer);
        DebugSupport.debug("ln_emb2", lnattn, this.layerIndex);

        // Same pattern as above: the quantized copy is released on every path.
        final AbstractTensor postFF;
        try (AbstractTensor qlnattn = this.model.maybeQuantize(lnattn)) {
            postFF = this.ffBlock.forward(qlnattn, tensorReducer);
            DebugSupport.debug("post_ff", postFF, this.layerIndex);
        }

        // Residual connection around the feed-forward block.
        TensorOperationsProvider.get().accumulate(
                postFF, postAttention, this.model.c.embeddingSegmentStart(), this.model.c.embeddingSegmentLength());
        DebugSupport.debug("post_ff_res", postFF, this.layerIndex);

        // Release intermediates created here. lnemb is the caller's tensor when no pre-attention
        // norm is configured, so only close it when it is a distinct tensor.
        if (lnemb != embedding) {
            lnemb.close();
        }
        lnattn.close();
        postAttention.close();

        return this.postFFNorm.map(ln -> {
            AbstractTensor lnout = ln.forward(postFF, normReducer);
            DebugSupport.debug("ln_out", lnout, this.layerIndex);
            // postFF is superseded by the normed output and is not returned, so free it.
            postFF.close();
            return lnout;
        }).orElse(postFF);
    }
}
