package dev.langchain4j.service;

import dev.langchain4j.Internal;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.guardrail.ChatExecutor;
import dev.langchain4j.guardrail.GuardrailRequestParams;
import dev.langchain4j.guardrail.OutputGuardrailRequest;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.guardrail.GuardrailService;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link StreamingChatResponseHandler} used by AI services for streaming calls.
 *
 * <p>For every complete response the handler stores the {@link AiMessage} in chat memory and then either:
 * <ul>
 *   <li>executes each requested tool through its registered {@link ToolExecutor}, stores the results in
 *       memory, notifies the optional tool-execution handler and sends a follow-up streaming request with a
 *       new handler that carries the accumulated {@link TokenUsage}; or</li>
 *   <li>runs the output guardrails (when configured), forwards any buffered partial responses and passes
 *       the final {@link ChatResponse}, with accumulated token usage, to the optional complete-response
 *       handler.</li>
 * </ul>
 *
 * <p>When the AI service method has output guardrails, partial responses are buffered instead of being
 * forwarded immediately. They are forwarded only after the guardrails have been executed on the complete
 * response.
 */
@Internal
class AiServiceStreamingResponseHandler implements StreamingChatResponseHandler {

    private static final Logger LOG = LoggerFactory.getLogger(AiServiceStreamingResponseHandler.class);

    private final ChatExecutor chatExecutor;
    private final AiServiceContext context;
    private final Object memoryId;
    private final GuardrailRequestParams commonGuardrailParams;
    private final Object methodKey;

    private final Consumer<String> partialResponseHandler;
    private final Consumer<ToolExecution> toolExecutionHandler;
    private final Consumer<ChatResponse> completeResponseHandler;
    private final Consumer<Throwable> errorHandler;

    private final ChatMemory temporaryMemory;
    private final TokenUsage tokenUsage;

    private final List<ToolSpecification> toolSpecifications;
    private final Map<String, ToolExecutor> toolExecutors;

    // Partial responses held back until output guardrails have been executed on the complete response.
    private final List<String> responseBuffer = new ArrayList<>();
    private final boolean hasOutputGuardrails;

    /**
     * Creates a handler for one streaming round trip with the model.
     *
     * @param chatExecutor            executor handed to output guardrails; must not be {@code null}
     * @param context                 AI service context (streaming model, chat memory, guardrails); must not be {@code null}
     * @param memoryId                id of the chat memory to read from and write to; must not be {@code null}
     * @param partialResponseHandler  receives streamed partial responses; must not be {@code null}
     * @param toolExecutionHandler    optional; notified after each tool execution
     * @param completeResponseHandler optional; receives the final response
     * @param errorHandler            optional; receives errors passed to {@link #onError(Throwable)}
     * @param temporaryMemory         memory used when the context has no chat memory configured
     * @param tokenUsage              token usage accumulated by previous round trips; must not be {@code null}
     * @param toolSpecifications      tool specifications sent with follow-up requests (copied via {@link Utils#copy})
     * @param toolExecutors           tool executors keyed by tool name (copied via {@link Utils#copy})
     * @param commonGuardrailParams   guardrail parameters reused when running output guardrails; may be {@code null}
     * @param methodKey               key of the AI service method whose output guardrails apply
     */
    AiServiceStreamingResponseHandler(
            ChatExecutor chatExecutor,
            AiServiceContext context,
            Object memoryId,
            Consumer<String> partialResponseHandler,
            Consumer<ToolExecution> toolExecutionHandler,
            Consumer<ChatResponse> completeResponseHandler,
            Consumer<Throwable> errorHandler,
            ChatMemory temporaryMemory,
            TokenUsage tokenUsage,
            List<ToolSpecification> toolSpecifications,
            Map<String, ToolExecutor> toolExecutors,
            GuardrailRequestParams commonGuardrailParams,
            Object methodKey) {
        this.chatExecutor = ValidationUtils.ensureNotNull(chatExecutor, "chatExecutor");
        this.context = ValidationUtils.ensureNotNull(context, "context");
        this.memoryId = ValidationUtils.ensureNotNull(memoryId, "memoryId");
        this.methodKey = methodKey;
        this.partialResponseHandler = ValidationUtils.ensureNotNull(partialResponseHandler, "partialResponseHandler");
        this.completeResponseHandler = completeResponseHandler;
        this.toolExecutionHandler = toolExecutionHandler;
        this.errorHandler = errorHandler;
        this.temporaryMemory = temporaryMemory;
        this.tokenUsage = ValidationUtils.ensureNotNull(tokenUsage, "tokenUsage");
        this.commonGuardrailParams = commonGuardrailParams;
        this.toolSpecifications = Utils.copy(toolSpecifications);
        this.toolExecutors = Utils.copy(toolExecutors);
        this.hasOutputGuardrails = this.context.guardrailService().hasOutputGuardrails(methodKey);
    }

    /**
     * Forwards a partial response to the partial-response handler, or buffers it while output guardrails
     * still have to run on the complete response.
     */
    @Override
    public void onPartialResponse(String partialResponse) {
        if (hasOutputGuardrails) {
            responseBuffer.add(partialResponse);
        } else {
            partialResponseHandler.accept(partialResponse);
        }
    }

    /**
     * Stores the AI message in memory, then either executes the requested tools and continues the
     * conversation, or finalizes the response.
     */
    @Override
    public void onCompleteResponse(ChatResponse completeResponse) {
        AiMessage aiMessage = completeResponse.aiMessage();
        addToMemory(aiMessage);

        if (aiMessage.hasToolExecutionRequests()) {
            executeToolsAndContinue(aiMessage, completeResponse);
        } else {
            finish(aiMessage, completeResponse);
        }
    }

    /** Executes every requested tool, records the results and streams a follow-up request to the model. */
    private void executeToolsAndContinue(AiMessage aiMessage, ChatResponse completeResponse) {
        List<ToolExecutionRequest> requests = aiMessage.toolExecutionRequests();

        // Resolve all executors before running any tool: a tool name without a registered executor is reported
        // through the error channel instead of failing midway with an anonymous NullPointerException.
        List<ToolExecutor> executors = new ArrayList<>(requests.size());
        for (ToolExecutionRequest request : requests) {
            ToolExecutor executor = toolExecutors.get(request.name());
            if (executor == null) {
                onError(new IllegalStateException("No tool executor found for tool '" + request.name() + "'"));
                return;
            }
            executors.add(executor);
        }

        for (int i = 0; i < requests.size(); i++) {
            ToolExecutionRequest request = requests.get(i);
            String result = executors.get(i).execute(request, memoryId);
            addToMemory(ToolExecutionResultMessage.from(request, result));
            if (toolExecutionHandler != null) {
                toolExecutionHandler.accept(ToolExecution.builder().request(request).result(result).build());
            }
        }

        ChatRequest followUpRequest = ChatRequest.builder()
                .messages(messagesToSend(memoryId))
                .toolSpecifications(toolSpecifications)
                .build();

        // Each round trip gets its own handler, carrying the token usage accumulated so far.
        AiServiceStreamingResponseHandler followUpHandler = new AiServiceStreamingResponseHandler(
                chatExecutor,
                context,
                memoryId,
                partialResponseHandler,
                toolExecutionHandler,
                completeResponseHandler,
                errorHandler,
                temporaryMemory,
                TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
                toolSpecifications,
                toolExecutors,
                commonGuardrailParams,
                methodKey);

        context.streamingChatModel.chat(followUpRequest, followUpHandler);
    }

    /**
     * Builds the final response with accumulated token usage, runs output guardrails when configured, flushes
     * buffered partial responses and delivers the result to the complete-response handler (if any).
     */
    private void finish(AiMessage aiMessage, ChatResponse completeResponse) {
        ChatResponse finalResponse = ChatResponse.builder()
                .aiMessage(aiMessage)
                .metadata(completeResponse.metadata().toBuilder()
                        .tokenUsage(tokenUsage.add(completeResponse.metadata().tokenUsage()))
                        .build())
                .build();

        if (hasOutputGuardrails) {
            if (commonGuardrailParams != null) {
                finalResponse = applyOutputGuardrails(finalResponse);
            }
            // Release the partial responses held back in onPartialResponse. This must happen even without a
            // complete-response handler; otherwise the mandatory partial-response handler never sees any token.
            responseBuffer.forEach(partialResponseHandler);
            responseBuffer.clear();
        }

        if (completeResponseHandler != null) {
            completeResponseHandler.accept(finalResponse);
        }
    }

    /** Runs the method's output guardrails on {@code response} and returns the response they produce. */
    private ChatResponse applyOutputGuardrails(ChatResponse response) {
        GuardrailRequestParams requestParams = GuardrailRequestParams.builder()
                .chatMemory(getMemory())
                .augmentationResult(commonGuardrailParams.augmentationResult())
                .userMessageTemplate(commonGuardrailParams.userMessageTemplate())
                .variables(commonGuardrailParams.variables())
                .build();

        OutputGuardrailRequest guardrailRequest = OutputGuardrailRequest.builder()
                .responseFromLLM(response)
                .chatExecutor(chatExecutor)
                .requestParams(requestParams)
                .build();

        // The method key is passed as-is; it identifies the AI service method, it is not a GuardrailService.
        return (ChatResponse) context.guardrailService().executeGuardrails(methodKey, guardrailRequest);
    }

    /** Returns the chat memory for this handler's memory id. */
    private ChatMemory getMemory() {
        return getMemory(memoryId);
    }

    /**
     * Returns the chat memory for {@code memId} from the context's memory service when chat memory is
     * configured, otherwise the temporary memory supplied at construction.
     */
    private ChatMemory getMemory(Object memId) {
        return context.hasChatMemory()
                ? context.chatMemoryService.getOrCreateChatMemory(memId)
                : temporaryMemory;
    }

    /** Appends {@code chatMessage} to this handler's chat memory. */
    private void addToMemory(ChatMessage chatMessage) {
        getMemory().add(chatMessage);
    }

    /** Returns the messages currently held in the chat memory for {@code memId}. */
    private List<ChatMessage> messagesToSend(Object memId) {
        return getMemory(memId).messages();
    }

    /**
     * Passes {@code error} to the error handler, or logs it at WARN level when no handler is set. Exceptions
     * thrown by the error handler itself are logged and not propagated.
     */
    @Override
    public void onError(Throwable error) {
        if (errorHandler == null) {
            LOG.warn("Ignored error", error);
            return;
        }
        try {
            errorHandler.accept(error);
        } catch (Exception e) {
            LOG.error("While handling the following error...", error);
            LOG.error("...the following error happened", e);
        }
    }
}
