package dev.langchain4j.service;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
/* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT.class */
class StreamingAiServicesWithToolsIT {

    /* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT$Assistant.class */
    /**
     * AI service contract used by every test in this class: one user message in,
     * a {@link TokenStream} out that the caller wires handlers onto and then starts.
     */
    interface Assistant {

        /**
         * Sends a single message to the assistant.
         *
         * @param str the text passed to the assistant
         * @return the token stream on which callbacks are registered
         */
        TokenStream chat(String str);
    }

    /* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT$TemperatureUnit.class */
    /**
     * Temperature units accepted by {@code WeatherService.currentTemperature}.
     *
     * <p>The constants use three different casings. The expected tool specification in
     * {@code WeatherService} lists exactly these spellings, so they must not be normalized.
     */
    enum TemperatureUnit {
        /** Upper-case constant. */
        CELSIUS,
        /** Lower-case constant. */
        fahrenheit,
        /** Capitalized constant. */
        Kelvin
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT$TransactionService.class */
    /**
     * Tool-bearing service used as a fixture: exposes a single {@code @Tool} method that
     * returns a fixed amount for the two known transaction IDs.
     */
    public static class TransactionService {

        /**
         * Tool specification the tests expect to be sent to the model for
         * {@link #getTransactionAmount(String)}.
         */
        static ToolSpecification EXPECTED_SPECIFICATION = ToolSpecification.builder()
                .name("getTransactionAmount")
                .description("returns amount of a given transaction")
                .parameters(JsonObjectSchema.builder()
                        .addStringProperty("arg0", "ID of a transaction")
                        .required("arg0")
                        .build())
                .build();

        TransactionService() {
        }

        /**
         * Looks up the amount of a transaction and logs the call to standard output.
         *
         * @param str the transaction ID; {@code "T001"} and {@code "T002"} are the only known IDs
         * @return {@code 11.1} for {@code "T001"}, {@code 22.2} for {@code "T002"}
         * @throws IllegalArgumentException if the ID is not one of the known IDs
         * @throws NullPointerException if {@code str} is {@code null}
         */
        @Tool("returns amount of a given transaction")
        Double getTransactionAmount(@P("ID of a transaction") String str) {
            System.out.printf("called getTransactionAmount(%s)%n", str);
            // Plain string switch; the decompiled hashCode()/boolean two-stage switch did not compile.
            switch (str) {
                case "T001":
                    return 11.1;
                case "T002":
                    return 22.2;
                default:
                    throw new IllegalArgumentException("Unknown transaction ID: " + str);
            }
        }
    }

    /* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT$TransactionServiceExecutor.class */
    static class TransactionServiceExecutor implements ToolExecutor {
        private final TransactionService transactionService = new TransactionService();

        TransactionServiceExecutor() {
        }

        public String execute(ToolExecutionRequest toolExecutionRequest, Object obj) {
            return this.transactionService.getTransactionAmount(StreamingAiServicesWithToolsIT.toMap(toolExecutionRequest.arguments()).get("arg0").toString()).toString();
        }
    }

    /* loaded from: input_file:dev/langchain4j/service/StreamingAiServicesWithToolsIT$WeatherService.class */
    /**
     * Tool-bearing fixture whose single tool takes a string and a {@link TemperatureUnit}
     * and always reports the same temperature.
     */
    static class WeatherService {

        /**
         * Tool specification the tests expect to be sent to the model for
         * {@link #currentTemperature(String, TemperatureUnit)}; the enum values are listed
         * with the same spelling as the {@link TemperatureUnit} constants.
         */
        static ToolSpecification EXPECTED_SPECIFICATION = ToolSpecification.builder()
                .name("currentTemperature")
                .parameters(JsonObjectSchema.builder()
                        .addStringProperty("arg0")
                        .addEnumProperty("arg1", List.of("CELSIUS", "fahrenheit", "Kelvin"))
                        .required("arg0", "arg1")
                        .build())
                .build();

        /** The fixed value returned by {@link #currentTemperature(String, TemperatureUnit)}. */
        static final int TEMPERATURE = 19;

        WeatherService() {
        }

        /**
         * Logs the call to standard output and returns {@link #TEMPERATURE} regardless of
         * the arguments.
         *
         * @param str first tool argument (the tests pass a city name)
         * @param temperatureUnit the requested unit; only logged
         * @return {@link #TEMPERATURE}
         */
        @Tool
        int currentTemperature(String str, TemperatureUnit temperatureUnit) {
            System.out.printf("called currentTemperature(%s, %s)%n", str, temperatureUnit);
            return TEMPERATURE;
        }
    }

    /** Package-private no-argument constructor; performs no initialization. */
    StreamingAiServicesWithToolsIT() {}

    /**
     * Method source for the parameterized tests.
     *
     * <p>Builds one OpenAI streaming model ({@code GPT_4_O_MINI}, temperature 0, request and
     * response logging on) whose base URL, API key and organization ID are read from the
     * {@code OPENAI_BASE_URL}, {@code OPENAI_API_KEY} and {@code OPENAI_ORGANIZATION_ID}
     * environment variables.
     *
     * @return a stream containing that single model
     */
    static Stream<StreamingChatLanguageModel> models() {
        StreamingChatLanguageModel openAiModel = OpenAiStreamingChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                .temperature(0.0)
                .logRequests(true)
                .logResponses(true)
                .build();
        return Stream.of(openAiModel);
    }

    /**
     * Asks for the amount of transaction T001 and checks that the assistant calls
     * {@code getTransactionAmount("T001")} exactly once, answers with {@code 11.1}, and
     * sends two chat requests to the model, both carrying the expected tool specification.
     *
     * @param streamingChatLanguageModel the model under test, supplied by {@link #models()}
     * @throws Exception if the response does not arrive within 60 seconds or streaming fails
     */
    @MethodSource({"models"})
    @ParameterizedTest
    void should_execute_a_tool_then_answer(StreamingChatLanguageModel streamingChatLanguageModel) throws Exception {
        TransactionService transactionService = Mockito.spy(new TransactionService());
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        StreamingChatLanguageModel spyModel = Mockito.spy(streamingChatLanguageModel);

        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .chatMemory(chatMemory)
                .tools(transactionService)
                .build();

        // Bridges the callback-style stream to a blocking get() in the test thread.
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

        assistant.chat("What is the amounts of transaction T001?")
                .onNext(token -> {
                })
                .onComplete(futureResponse::complete)
                .onError(futureResponse::completeExceptionally)
                .start();

        Response<AiMessage> response = futureResponse.get(60L, TimeUnit.SECONDS);
        Assertions.assertThat(response.content().text()).contains("11.1");

        Mockito.verify(transactionService).getTransactionAmount("T001");
        Mockito.verifyNoMoreInteractions(transactionService);

        List<ChatMessage> messages = chatMemory.messages();

        // First request: only the first message in memory.
        ChatRequest firstRequest = ChatRequest.builder()
                .messages(messages.get(0))
                .toolSpecifications(TransactionService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(firstRequest), Mockito.any());

        // Second request: the first three messages in memory.
        ChatRequest secondRequest = ChatRequest.builder()
                .messages(messages.get(0), messages.get(1), messages.get(2))
                .toolSpecifications(TransactionService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(secondRequest), Mockito.any());
    }

    /**
     * Asks for the temperature in Munich in Celsius and checks that the assistant calls
     * {@code currentTemperature("Munich", CELSIUS)} exactly once, answers with the fixed
     * temperature, and sends two chat requests carrying the expected tool specification.
     *
     * @param streamingChatLanguageModel the model under test, supplied by {@link #models()}
     * @throws Exception if the response does not arrive within 60 seconds or streaming fails
     */
    @MethodSource({"models"})
    @ParameterizedTest
    void should_use_tool_with_enum_parameter(StreamingChatLanguageModel streamingChatLanguageModel) throws Exception {
        WeatherService weatherService = Mockito.spy(new WeatherService());
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        StreamingChatLanguageModel spyModel = Mockito.spy(streamingChatLanguageModel);

        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .chatMemory(chatMemory)
                .tools(weatherService)
                .build();

        // Bridges the callback-style stream to a blocking get() in the test thread.
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

        assistant.chat("What is the temperature in Munich now, in Celsius?")
                .onNext(token -> {
                })
                .onComplete(futureResponse::complete)
                .onError(futureResponse::completeExceptionally)
                .start();

        Response<AiMessage> response = futureResponse.get(60L, TimeUnit.SECONDS);
        Assertions.assertThat(response.content().text()).contains(String.valueOf(WeatherService.TEMPERATURE));

        Mockito.verify(weatherService).currentTemperature("Munich", TemperatureUnit.CELSIUS);
        Mockito.verifyNoMoreInteractions(weatherService);

        List<ChatMessage> messages = chatMemory.messages();

        // First request: only the first message in memory.
        ChatRequest firstRequest = ChatRequest.builder()
                .messages(messages.get(0))
                .toolSpecifications(WeatherService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(firstRequest), Mockito.any());

        // Second request: the first three messages in memory.
        ChatRequest secondRequest = ChatRequest.builder()
                .messages(messages.get(0), messages.get(1), messages.get(2))
                .toolSpecifications(WeatherService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(secondRequest), Mockito.any());
    }

    /**
     * Registers the transaction tool through a {@link ToolProvider} instead of a tool object
     * and checks that the executor runs exactly once, the answer contains {@code 11.1}, and
     * the model receives two chat requests carrying the expected tool specification and no
     * other interactions beyond those tolerated by
     * {@link #verifyNoMoreInteractionsFor(StreamingChatLanguageModel)}.
     *
     * @throws Exception if the response does not arrive within 60 seconds or streaming fails
     */
    @Test
    void should_use_tool_provider() throws Exception {
        ToolExecutor toolExecutor = Mockito.spy(new TransactionServiceExecutor());
        ToolProvider toolProvider = toolProviderRequest -> ToolProviderResult.builder()
                .add(TransactionService.EXPECTED_SPECIFICATION, toolExecutor)
                .build();

        StreamingChatLanguageModel spyModel = Mockito.spy(models().findFirst().get());
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);

        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .chatMemory(chatMemory)
                .toolProvider(toolProvider)
                .build();

        // Bridges the callback-style stream to a blocking get() in the test thread.
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

        assistant.chat("What is the amounts of transactions T001?")
                .onNext(token -> {
                })
                .onComplete(futureResponse::complete)
                .onError(futureResponse::completeExceptionally)
                .start();

        Response<AiMessage> response = futureResponse.get(60L, TimeUnit.SECONDS);
        Assertions.assertThat(response.content().text()).contains("11.1");

        Mockito.verify(toolExecutor).execute(Mockito.any(), Mockito.any());
        Mockito.verifyNoMoreInteractions(toolExecutor);

        List<ChatMessage> messages = chatMemory.messages();

        // First request: only the first message in memory.
        ChatRequest firstRequest = ChatRequest.builder()
                .messages(messages.get(0))
                .toolSpecifications(TransactionService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(firstRequest), Mockito.any());

        // Second request: the first three messages in memory.
        ChatRequest secondRequest = ChatRequest.builder()
                .messages(messages.get(0), messages.get(1), messages.get(2))
                .toolSpecifications(TransactionService.EXPECTED_SPECIFICATION)
                .build();
        Mockito.verify(spyModel).chat(Mockito.eq(secondRequest), Mockito.any());

        verifyNoMoreInteractionsFor(spyModel);
    }

    /**
     * Parses a JSON object string into a map using a fresh Jackson {@link ObjectMapper}.
     *
     * @param str the JSON text to parse
     * @return the parsed key/value pairs
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if parsing fails
     */
    private static Map<String, Object> toMap(String str) {
        TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {
        };
        try {
            return new ObjectMapper().readValue(str, mapType);
        } catch (JsonProcessingException parseFailure) {
            throw new RuntimeException(parseFailure);
        }
    }

    /**
     * Asks for the temperature in two cities and checks that the {@code onToolExecuted}
     * handler receives one {@link ToolExecution} per tool call, in order (Munich, then
     * London), each with the expected tool name, arguments and result.
     *
     * @throws Exception if the response does not arrive within 60 seconds or streaming fails
     */
    @Test
    void should_invoke_tool_execution_handler() throws Exception {
        WeatherService weatherService = Mockito.spy(new WeatherService());
        StreamingChatLanguageModel spyModel = Mockito.spy(models().findFirst().get());

        Assistant assistant = AiServices.builder(Assistant.class)
                .streamingChatLanguageModel(spyModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .tools(weatherService)
                .build();

        List<ToolExecution> toolExecutions = new ArrayList<>();
        // Bridges the callback-style stream to a blocking get() in the test thread.
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

        assistant.chat("What is the temperature in Munich and London, in Celsius?")
                .onNext(token -> {
                })
                .onToolExecuted(toolExecutions::add)
                .onComplete(futureResponse::complete)
                .onError(futureResponse::completeExceptionally)
                .start();

        String expectedTemperature = String.valueOf(WeatherService.TEMPERATURE);

        Response<AiMessage> response = futureResponse.get(60L, TimeUnit.SECONDS);
        Assertions.assertThat(response.content().text()).contains(expectedTemperature);

        Mockito.verify(weatherService).currentTemperature("Munich", TemperatureUnit.CELSIUS);
        Mockito.verify(weatherService).currentTemperature("London", TemperatureUnit.CELSIUS);
        Mockito.verifyNoMoreInteractions(weatherService);

        Assertions.assertThat(toolExecutions).hasSize(2);

        ToolExecution munichExecution = toolExecutions.get(0);
        Assertions.assertThat(munichExecution.request().name()).isEqualTo("currentTemperature");
        Assertions.assertThat(munichExecution.request().arguments())
                .isEqualToIgnoringWhitespace("{\"arg0\":\"Munich\", \"arg1\": \"CELSIUS\"}");
        Assertions.assertThat(munichExecution.result()).isEqualTo(expectedTemperature);

        ToolExecution londonExecution = toolExecutions.get(1);
        Assertions.assertThat(londonExecution.request().name()).isEqualTo("currentTemperature");
        Assertions.assertThat(londonExecution.request().arguments())
                .isEqualToIgnoringWhitespace("{\"arg0\":\"London\", \"arg1\":\"CELSIUS\"}");
        Assertions.assertThat(londonExecution.result()).isEqualTo(expectedTemperature);
    }

    /**
     * Asserts that the given spied model has no unverified interactions, while tolerating
     * any number of calls (including none) to {@code doChat}, {@code defaultRequestParameters},
     * {@code supportedCapabilities} and {@code listeners}.
     *
     * <p>Each tolerated method is verified with {@code atLeastOnce()} so that Mockito marks
     * its invocations as verified; a failure of that verification is deliberately discarded.
     *
     * @param streamingChatLanguageModel the Mockito spy or mock to check
     */
    public static void verifyNoMoreInteractionsFor(StreamingChatLanguageModel streamingChatLanguageModel) {
        try {
            Mockito.verify(streamingChatLanguageModel, Mockito.atLeastOnce())
                    .doChat(Mockito.any(), Mockito.any());
        } catch (Throwable ignored) {
            // Intentionally ignored: it does not matter whether doChat was called.
        }

        try {
            Mockito.verify(streamingChatLanguageModel, Mockito.atLeastOnce())
                    .defaultRequestParameters();
        } catch (Throwable ignored) {
            // Intentionally ignored: it does not matter whether defaultRequestParameters was called.
        }

        try {
            Mockito.verify(streamingChatLanguageModel, Mockito.atLeastOnce())
                    .supportedCapabilities();
        } catch (Throwable ignored) {
            // Intentionally ignored: it does not matter whether supportedCapabilities was called.
        }

        try {
            Mockito.verify(streamingChatLanguageModel, Mockito.atLeastOnce())
                    .listeners();
        } catch (Throwable ignored) {
            // Intentionally ignored: it does not matter whether listeners was called.
        }

        Mockito.verifyNoMoreInteractions(streamingChatLanguageModel);
    }
}
