package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;

/**
 * Integration tests for an {@link AiServices}-built assistant that is backed by a streaming OpenAI model,
 * exposes a {@link Calculator} tool and is configured WITHOUT chat memory.
 *
 * <p>Each test checks the final streamed answer and the tool invocations. It also checks the exact sequence
 * of {@link ChatRequest}s sent to the model: the user message followed by the AI tool-call / tool-result
 * round trips. The model is wrapped in a Mockito spy so its calls can be captured.
 *
 * <p>Runs only when the {@code OPENAI_API_KEY} environment variable is set.
 */
@ExtendWith(MockitoExtension.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class StreamingAiServicesWithToolsWithoutMemoryIT {

    private static final String SINGLE_NUMBER_QUESTION =
            "What is the square root of 485906798473894056 in scientific notation?";

    private static final String TWO_NUMBERS_QUESTION =
            "What is the square root of 485906798473894056 and 97866249624785 in scientific notation?";

    /** Upper bound for waiting on the streamed response from the remote model. */
    private static final long TIMEOUT_SECONDS = 60;

    @Spy
    StreamingChatLanguageModel spyModel = OpenAiStreamingChatModel.builder()
            .baseUrl(System.getenv("OPENAI_BASE_URL"))
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
            .modelName(OpenAiChatModelName.GPT_4_O_MINI)
            .temperature(0.0)
            .logRequests(true)
            .logResponses(true)
            .build();

    @Captor
    ArgumentCaptor<ChatRequest> chatRequestCaptor;

    /** AI service contract under test: a single streaming chat method. */
    interface Assistant {
        TokenStream chat(String str);
    }

    /** Tool exposed to the model; spied in each test to verify the arguments it was called with. */
    static class Calculator {

        @Tool("calculates the square root of the provided number")
        double squareRoot(@P("number to operate on") double d) {
            System.out.printf("called squareRoot(%s)%n", d);
            return Math.sqrt(d);
        }
    }

    /** A single tool call: USER -> (AI tool call, TOOL result) -> final answer, i.e. two model requests. */
    @Test
    void should_execute_a_tool_then_answer() throws Exception {
        Calculator calculator = Mockito.spy(new Calculator());
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .tools(calculator)
                .build();

        Response<AiMessage> response = chatAndAwait(assistant, SINGLE_NUMBER_QUESTION);

        Assertions.assertThat(response.content().text()).contains("6.97");
        Assertions.assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
        assertTokenUsageIsConsistent(response.tokenUsage());

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verifyNoMoreInteractions(calculator);

        Mockito.verify(spyModel, Mockito.times(2)).chat(chatRequestCaptor.capture(), Mockito.any());
        List<ChatRequest> requests = chatRequestCaptor.getAllValues();

        List<ChatMessage> firstMessages = requests.get(0).messages();
        Assertions.assertThat(firstMessages).hasSize(1);
        assertUserMessage(firstMessages.get(0), SINGLE_NUMBER_QUESTION);

        List<ChatMessage> secondMessages = requests.get(1).messages();
        Assertions.assertThat(secondMessages).hasSize(3);
        assertUserMessage(secondMessages.get(0), SINGLE_NUMBER_QUESTION);
        Assertions.assertThat(secondMessages.get(1).type()).isEqualTo(ChatMessageType.AI);
        Assertions.assertThat(secondMessages.get(2).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);
    }

    /**
     * Parallel tool calls are disabled on a dedicated model, so the two tool calls arrive in separate AI
     * messages. This gives three model requests.
     */
    @Test
    void should_execute_multiple_tools_sequentially_then_answer() throws Exception {
        Calculator calculator = Mockito.spy(new Calculator());
        StreamingChatLanguageModel sequentialModel = Mockito.spy(OpenAiStreamingChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                .parallelToolCalls(false)
                .temperature(0.0)
                .logRequests(true)
                .logResponses(true)
                .build());
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(sequentialModel)
                .tools(calculator)
                .build();

        Response<AiMessage> response = chatAndAwait(assistant, TWO_NUMBERS_QUESTION);

        Assertions.assertThat(response.content().text()).contains("6.97", "9.89");
        Assertions.assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
        assertTokenUsageIsConsistent(response.tokenUsage());

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verify(calculator).squareRoot(97866249624785.0);
        Mockito.verifyNoMoreInteractions(calculator);

        Mockito.verify(sequentialModel, Mockito.times(3)).chat(chatRequestCaptor.capture(), Mockito.any());
        List<ChatRequest> requests = chatRequestCaptor.getAllValues();

        List<ChatMessage> firstMessages = requests.get(0).messages();
        Assertions.assertThat(firstMessages).hasSize(1);
        assertUserMessage(firstMessages.get(0), TWO_NUMBERS_QUESTION);

        List<ChatMessage> secondMessages = requests.get(1).messages();
        Assertions.assertThat(secondMessages).hasSize(3);
        assertUserMessage(secondMessages.get(0), TWO_NUMBERS_QUESTION);
        Assertions.assertThat(secondMessages.get(1).type()).isEqualTo(ChatMessageType.AI);
        Assertions.assertThat(secondMessages.get(2).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);

        List<ChatMessage> thirdMessages = requests.get(2).messages();
        Assertions.assertThat(thirdMessages).hasSize(5);
        assertUserMessage(thirdMessages.get(0), TWO_NUMBERS_QUESTION);
        Assertions.assertThat(thirdMessages.get(1).type()).isEqualTo(ChatMessageType.AI);
        Assertions.assertThat(thirdMessages.get(2).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);
        Assertions.assertThat(thirdMessages.get(3).type()).isEqualTo(ChatMessageType.AI);
        Assertions.assertThat(thirdMessages.get(4).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);
    }

    /**
     * With the default model settings, both tool calls are expected in one AI message. That message is
     * followed by two tool results, giving two model requests.
     */
    @Test
    void should_execute_multiple_tools_in_parallel_then_answer() throws Exception {
        Calculator calculator = Mockito.spy(new Calculator());
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .tools(calculator)
                .build();

        Response<AiMessage> response = chatAndAwait(assistant, TWO_NUMBERS_QUESTION);

        Assertions.assertThat(response.content().text()).contains("6.97", "9.89");
        Assertions.assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
        assertTokenUsageIsConsistent(response.tokenUsage());

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verify(calculator).squareRoot(97866249624785.0);
        Mockito.verifyNoMoreInteractions(calculator);

        Mockito.verify(spyModel, Mockito.times(2)).chat(chatRequestCaptor.capture(), Mockito.any());
        List<ChatRequest> requests = chatRequestCaptor.getAllValues();

        List<ChatMessage> firstMessages = requests.get(0).messages();
        Assertions.assertThat(firstMessages).hasSize(1);
        assertUserMessage(firstMessages.get(0), TWO_NUMBERS_QUESTION);

        List<ChatMessage> secondMessages = requests.get(1).messages();
        Assertions.assertThat(secondMessages).hasSize(4);
        assertUserMessage(secondMessages.get(0), TWO_NUMBERS_QUESTION);
        Assertions.assertThat(secondMessages.get(1).type()).isEqualTo(ChatMessageType.AI);
        Assertions.assertThat(secondMessages.get(2).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);
        Assertions.assertThat(secondMessages.get(3).type()).isEqualTo(ChatMessageType.TOOL_EXECUTION_RESULT);
    }

    /**
     * Sends {@code userMessage} through the streaming assistant and blocks until the stream completes.
     *
     * @param assistant   the AI service under test
     * @param userMessage the prompt to send
     * @return the final response delivered to {@code onComplete}
     * @throws Exception if the stream reports an error (wrapped in an {@code ExecutionException}), the wait is
     *                   interrupted, or no completion arrives within {@link #TIMEOUT_SECONDS}
     */
    private static Response<AiMessage> chatAndAwait(Assistant assistant, String userMessage) throws Exception {
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
        assistant.chat(userMessage)
                .onNext(token -> {
                    // partial tokens are intentionally ignored; only the final response is asserted
                })
                .onComplete(futureResponse::complete)
                .onError(futureResponse::completeExceptionally)
                .start();
        return futureResponse.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /** Asserts that input and output token counts are positive and that the total is their sum. */
    private static void assertTokenUsageIsConsistent(TokenUsage tokenUsage) {
        Assertions.assertThat(tokenUsage.inputTokenCount()).isPositive();
        Assertions.assertThat(tokenUsage.outputTokenCount()).isPositive();
        Assertions.assertThat(tokenUsage.totalTokenCount())
                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
    }

    /** Asserts that {@code message} is a USER message whose single text equals {@code expectedText}. */
    private static void assertUserMessage(ChatMessage message, String expectedText) {
        Assertions.assertThat(message.type()).isEqualTo(ChatMessageType.USER);
        Assertions.assertThat(((UserMessage) message).singleText()).isEqualTo(expectedText);
    }
}
