package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

/**
 * Integration tests for AI services whose answers are streamed through a {@link TokenStream}.
 *
 * <p>Every test talks to OpenAI ({@code gpt-4o-mini}) through {@link OpenAiStreamingChatModel},
 * so the class only runs when the {@code OPENAI_API_KEY} environment variable is non-empty.
 */
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class StreamingAiServicesIT {

    /** Maximum time, in seconds, to wait for a streamed interaction to finish. */
    private static final long TIMEOUT_SECONDS = 30L;

    private static final String SINGLE_SQUARE_ROOT_QUESTION =
            "What is the square root of 485906798473894056 in scientific notation?";

    private static final String TWO_SQUARE_ROOTS_QUESTION =
            "What is the square root of 485906798473894056 and 97866249624785 in scientific notation?";

    /** AI service under test: one user message in, a stream of tokens out. */
    interface Assistant {
        TokenStream chat(String userMessage);
    }

    /** Tool exposed to the model; spied on by the tests to verify the exact invocations. */
    static class Calculator {

        // NOTE(review): the tests expect the tool argument to be named "arg0", presumably because
        // parameter names are not compiled into the class files; confirm before enabling -parameters.
        @Tool("calculates the square root of the provided number")
        double squareRoot(@P("number to operate on") double number) {
            return Math.sqrt(number);
        }
    }

    /** Text assembled from the streamed tokens, plus the final response of the same stream. */
    private static final class StreamedAnswer {
        final String text;
        final Response<AiMessage> response;

        StreamedAnswer(String text, Response<AiMessage> response) {
            this.text = text;
            this.response = response;
        }
    }

    static Stream<StreamingChatLanguageModel> models() {
        return Stream.of(
                OpenAiStreamingChatModel.builder()
                        .baseUrl(System.getenv("OPENAI_BASE_URL"))
                        .apiKey(System.getenv("OPENAI_API_KEY"))
                        .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                        .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                        .logRequests(true)
                        .logResponses(true)
                        .build());
    }

    /**
     * Starts {@code tokenStream} and blocks until it completes.
     *
     * @return the concatenated tokens and the final response
     * @throws java.util.concurrent.ExecutionException if the stream reports an error
     * @throws java.util.concurrent.TimeoutException if it does not finish within {@link #TIMEOUT_SECONDS}
     */
    private static StreamedAnswer streamAndAwait(TokenStream tokenStream) throws Exception {
        StringBuilder text = new StringBuilder();
        CompletableFuture<StreamedAnswer> future = new CompletableFuture<>();
        tokenStream
                .onNext(text::append)
                // Snapshot the text inside the callback, after the last token has been appended.
                .onComplete(response -> future.complete(new StreamedAnswer(text.toString(), response)))
                .onError(future::completeExceptionally)
                .start();
        return future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Asserts the properties shared by every finished answer: the response text equals the
     * streamed text, token usage is positive and adds up, and the model stopped normally.
     */
    private static void assertCompleteResponse(StreamedAnswer answer) {
        Response<AiMessage> response = answer.response;
        Assertions.assertThat(response.content().text()).isEqualTo(answer.text);

        TokenUsage tokenUsage = response.tokenUsage();
        Assertions.assertThat(tokenUsage.inputTokenCount()).isPositive();
        Assertions.assertThat(tokenUsage.outputTokenCount()).isPositive();
        Assertions.assertThat(tokenUsage.totalTokenCount())
                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

        Assertions.assertThat(response.finishReason()).isEqualTo(FinishReason.STOP);
    }

    /** Asserts that {@code request} calls {@code squareRoot} with the given JSON arguments. */
    private static void assertSquareRootRequest(ToolExecutionRequest request, String expectedArguments) {
        Assertions.assertThat(request.id()).isNotBlank();
        Assertions.assertThat(request.name()).isEqualTo("squareRoot");
        Assertions.assertThat(request.arguments()).isEqualToIgnoringWhitespace(expectedArguments);
    }

    /** Asserts that {@code message} is the {@code squareRoot} result answering {@code request}. */
    private static void assertSquareRootResult(
            ChatMessage message, ToolExecutionRequest request, String expectedText) {
        ToolExecutionResultMessage result = (ToolExecutionResultMessage) message;
        Assertions.assertThat(result.id()).isEqualTo(request.id());
        Assertions.assertThat(result.toolName()).isEqualTo("squareRoot");
        Assertions.assertThat(result.text()).isEqualTo(expectedText);
    }

    /** A plain question is answered by streaming, and the streamed text matches the final response. */
    @MethodSource("models")
    @ParameterizedTest
    void should_stream_answer(StreamingChatLanguageModel model) throws Exception {
        Assistant assistant = AiServices.create(Assistant.class, model);

        StreamedAnswer answer = streamAndAwait(assistant.chat("What is the capital of Germany?"));

        Assertions.assertThat(answer.text).contains("Berlin");
        assertCompleteResponse(answer);
    }

    /** The contents produced by a (mocked) retrieval augmentor are delivered to {@code onRetrieved}. */
    @MethodSource("models")
    @ParameterizedTest
    void should_callback_with_content(StreamingChatLanguageModel model) throws Exception {
        List<Content> contents = new ArrayList<>();
        contents.add(Content.from("This is additional content"));

        RetrievalAugmentor retrievalAugmentor = Mockito.mock(RetrievalAugmentor.class);
        Mockito.when(retrievalAugmentor.augment(Mockito.any(AugmentationRequest.class)))
                .thenReturn(AugmentationResult.builder()
                        .chatMessage(UserMessage.from("What is the capital of Germany?"))
                        .contents(contents)
                        .build());

        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(model)
                .retrievalAugmentor(retrievalAugmentor)
                .build();

        StringBuilder answer = new StringBuilder();
        CompletableFuture<List<Content>> retrieved = new CompletableFuture<>();
        assistant.chat("What is the capital of Germany?")
                .onNext(answer::append)
                .onRetrieved(retrieved::complete)
                // Surface an early failure right away instead of waiting for the timeout;
                // once the contents have been delivered this is a no-op.
                .onError(retrieved::completeExceptionally)
                .start();

        Assertions.assertThat(retrieved.get(TIMEOUT_SECONDS, TimeUnit.SECONDS))
                .hasSize(1)
                .anySatisfy(content -> Assertions.assertThat(content.textSegment().text())
                        .isEqualTo("This is additional content"));
    }

    /** Two consecutive turns share chat memory, which ends up holding both user/AI exchanges. */
    @MethodSource("models")
    @ParameterizedTest
    void should_stream_answers_with_memory(StreamingChatLanguageModel model) throws Exception {
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(model)
                .chatMemory(chatMemory)
                .build();

        Response<AiMessage> firstResponse = streamAndAwait(assistant.chat("Hi, my name is Klaus")).response;
        Assertions.assertThat(firstResponse.content().text()).contains("Klaus");

        Response<AiMessage> secondResponse = streamAndAwait(assistant.chat("What is my name?")).response;
        Assertions.assertThat(secondResponse.content().text()).contains("Klaus");

        List<ChatMessage> messages = chatMemory.messages();
        Assertions.assertThat(messages).hasSize(4);

        Assertions.assertThat(messages.get(0)).isInstanceOf(UserMessage.class);
        Assertions.assertThat(messages.get(0).text()).isEqualTo("Hi, my name is Klaus");

        Assertions.assertThat(messages.get(1)).isInstanceOf(AiMessage.class);
        Assertions.assertThat(messages.get(1)).isEqualTo(firstResponse.content());

        Assertions.assertThat(messages.get(2)).isInstanceOf(UserMessage.class);
        Assertions.assertThat(messages.get(2).text()).isEqualTo("What is my name?");

        Assertions.assertThat(messages.get(3)).isInstanceOf(AiMessage.class);
        Assertions.assertThat(messages.get(3)).isEqualTo(secondResponse.content());
    }

    /** A single tool call is executed and its result is used in the streamed final answer. */
    @MethodSource("models")
    @ParameterizedTest
    void should_execute_a_tool_then_stream_answer(StreamingChatLanguageModel model) throws Exception {
        Calculator calculator = Mockito.spy(new Calculator());
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(model)
                .chatMemory(chatMemory)
                .tools(calculator)
                .build();

        StreamedAnswer answer = streamAndAwait(assistant.chat(SINGLE_SQUARE_ROOT_QUESTION));

        Assertions.assertThat(answer.text).contains("6.97");
        assertCompleteResponse(answer);

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verifyNoMoreInteractions(calculator);

        // Expected memory: user question, tool request, tool result, final answer.
        List<ChatMessage> messages = chatMemory.messages();
        Assertions.assertThat(messages).hasSize(4);

        Assertions.assertThat(messages.get(0)).isInstanceOf(UserMessage.class);
        Assertions.assertThat(messages.get(0).text()).isEqualTo(SINGLE_SQUARE_ROOT_QUESTION);

        AiMessage toolCall = (AiMessage) messages.get(1);
        Assertions.assertThat(toolCall.text()).isNull();
        Assertions.assertThat(toolCall.toolExecutionRequests()).hasSize(1);
        ToolExecutionRequest request = toolCall.toolExecutionRequests().get(0);
        assertSquareRootRequest(request, "{\"arg0\": 485906798473894056}");

        assertSquareRootResult(messages.get(2), request, "6.97070153193991E8");

        Assertions.assertThat(messages.get(3)).isInstanceOf(AiMessage.class);
        Assertions.assertThat(messages.get(3).text()).contains("6.97");
    }

    /** With parallel tool calls disabled, two tool calls arrive in two separate AI messages. */
    @Test
    void should_execute_multiple_tools_sequentially_then_answer() throws Exception {
        OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                .parallelToolCalls(false)
                .temperature(0.0)
                .logRequests(true)
                .logResponses(true)
                .build();
        Calculator calculator = Mockito.spy(new Calculator());
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(model)
                .chatMemory(chatMemory)
                .tools(calculator)
                .build();

        StreamedAnswer answer = streamAndAwait(assistant.chat(TWO_SQUARE_ROOTS_QUESTION));

        Assertions.assertThat(answer.text).contains("6.97", "9.89");
        assertCompleteResponse(answer);

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verify(calculator).squareRoot(97866249624785.0);
        Mockito.verifyNoMoreInteractions(calculator);

        // Expected memory: question, (request, result) x 2, final answer.
        List<ChatMessage> messages = chatMemory.messages();
        Assertions.assertThat(messages).hasSize(6);

        Assertions.assertThat(messages.get(0)).isInstanceOf(UserMessage.class);
        Assertions.assertThat(messages.get(0).text()).isEqualTo(TWO_SQUARE_ROOTS_QUESTION);

        AiMessage firstToolCall = (AiMessage) messages.get(1);
        Assertions.assertThat(firstToolCall.text()).isNull();
        Assertions.assertThat(firstToolCall.toolExecutionRequests()).hasSize(1);
        ToolExecutionRequest firstRequest = firstToolCall.toolExecutionRequests().get(0);
        assertSquareRootRequest(firstRequest, "{\"arg0\": 485906798473894056}");

        assertSquareRootResult(messages.get(2), firstRequest, "6.97070153193991E8");

        AiMessage secondToolCall = (AiMessage) messages.get(3);
        Assertions.assertThat(secondToolCall.text()).isNull();
        Assertions.assertThat(secondToolCall.toolExecutionRequests()).hasSize(1);
        ToolExecutionRequest secondRequest = secondToolCall.toolExecutionRequests().get(0);
        assertSquareRootRequest(secondRequest, "{\"arg0\": 97866249624785}");

        assertSquareRootResult(messages.get(4), secondRequest, "9892737.215997653");

        Assertions.assertThat(messages.get(5)).isInstanceOf(AiMessage.class);
        Assertions.assertThat(messages.get(5).text()).contains("6.97", "9.89");
    }

    /** With default settings, both tool calls arrive together in a single AI message. */
    @Test
    void should_execute_multiple_tools_in_parallel_then_answer() throws Exception {
        Calculator calculator = Mockito.spy(new Calculator());
        OpenAiStreamingChatModel model = OpenAiStreamingChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                .temperature(0.0)
                .logRequests(true)
                .logResponses(true)
                .build();
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(model)
                .chatMemory(chatMemory)
                .tools(calculator)
                .build();

        StreamedAnswer answer = streamAndAwait(assistant.chat(TWO_SQUARE_ROOTS_QUESTION));

        Assertions.assertThat(answer.text).contains("6.97", "9.89");
        assertCompleteResponse(answer);

        Mockito.verify(calculator).squareRoot(485906798473894056.0);
        Mockito.verify(calculator).squareRoot(97866249624785.0);
        Mockito.verifyNoMoreInteractions(calculator);

        // Expected memory: question, one message with both requests, two results, final answer.
        List<ChatMessage> messages = chatMemory.messages();
        Assertions.assertThat(messages).hasSize(5);

        Assertions.assertThat(messages.get(0)).isInstanceOf(UserMessage.class);
        Assertions.assertThat(messages.get(0).text()).isEqualTo(TWO_SQUARE_ROOTS_QUESTION);

        AiMessage toolCalls = (AiMessage) messages.get(1);
        Assertions.assertThat(toolCalls.text()).isNull();
        Assertions.assertThat(toolCalls.toolExecutionRequests()).hasSize(2);
        ToolExecutionRequest firstRequest = toolCalls.toolExecutionRequests().get(0);
        assertSquareRootRequest(firstRequest, "{\"arg0\": 485906798473894056}");
        ToolExecutionRequest secondRequest = toolCalls.toolExecutionRequests().get(1);
        assertSquareRootRequest(secondRequest, "{\"arg0\": 97866249624785}");

        assertSquareRootResult(messages.get(2), firstRequest, "6.97070153193991E8");
        assertSquareRootResult(messages.get(3), secondRequest, "9892737.215997653");

        Assertions.assertThat(messages.get(4)).isInstanceOf(AiMessage.class);
        Assertions.assertThat(messages.get(4).text()).contains("6.97", "9.89");
    }
}
