package dev.langchain4j.service;

import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.LanguageModelQueryRouter;
import dev.langchain4j.rag.query.transformer.ExpandingQueryTransformer;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

/**
 * Integration tests for {@link AiServices} used together with retrieval-augmented generation (RAG)
 * components: content retrievers, query transformers, query routers, content aggregators and
 * metadata filters.
 *
 * <p>Before each test, {@code miles-of-smiles-terms-of-use.txt} is split, embedded and stored in
 * {@link #embeddingStore}. Most tests then ask about cancelling a booking and expect the answer to
 * mention one of the numbers from the cancellation policy. The tests only run when the
 * {@code OPENAI_API_KEY} environment variable is set.
 */
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class AiServicesWithRagIT {

    /** From the policy: "Reservations can be cancelled up to 61 days prior to the start of the booking period." */
    private static final String ALLOWED_CANCELLATION_PERIOD_DAYS = "61";

    /** From the policy: "If the booking period is less than 17 days, cancellations are not permitted." */
    private static final String MIN_BOOKING_PERIOD_DAYS = "17";

    /** Question whose answer is expected to cite the cancellation policy. */
    private static final String CANCELLATION_QUESTION = "Can I cancel my booking?";

    /** Question the ambiguity tests use; it does not obviously match any retriever description. */
    private static final String AMBIGUOUS_QUESTION = "Hey what's up?";

    /** Store holding the embedded terms of use; populated in {@link #beforeEach()}. */
    final EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

    /** Embedding model used both for ingestion and for querying. */
    final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

    interface Assistant {
        String answer(String query);
    }

    interface AssistantReturningResult {
        Result<String> answer(String query);
    }

    // The annotation must be fully qualified: the import of dev.langchain4j.data.message.UserMessage
    // shadows the same-package dev.langchain4j.service.UserMessage annotation.
    interface MultiUserAssistant {
        String answer(@MemoryId int memoryId, @dev.langchain4j.service.UserMessage String userMessage);
    }

    interface PersonalizedAssistant {
        String chat(@MemoryId String memoryId, @dev.langchain4j.service.UserMessage String userMessage);
    }

    @BeforeEach
    void beforeEach() {
        ingest("miles-of-smiles-terms-of-use.txt", embeddingStore, embeddingModel);
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_content_retriever(ChatLanguageModel chatLanguageModel) {
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(termsOfUseRetriever(1))
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_content_retriever_and_chat_memory(ChatLanguageModel chatLanguageModel) {
        ContentRetriever contentRetriever = Mockito.spy(termsOfUseRetriever(1));

        ChatMessage greeting = UserMessage.from("Hello");
        ChatMessage greetingReply = AiMessage.from("Hi, how can I help you today?");
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
        chatMemory.add(greeting);
        chatMemory.add(greetingReply);

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(contentRetriever)
                .chatMemory(chatMemory)
                .build();

        String question = "In which cases can I cancel my booking?";
        assertMentionsCancellationPolicy(assistant.answer(question));

        // Expects the retriever to see only the pre-existing conversation, under memory id "default".
        Metadata expectedMetadata =
                Metadata.from(UserMessage.from(question), "default", Arrays.asList(greeting, greetingReply));
        Mockito.verify(contentRetriever).retrieve(Query.from(question, expectedMetadata));
        Mockito.verifyNoMoreInteractions(contentRetriever);
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_content_retriever_and_chat_memory_provider(ChatLanguageModel chatLanguageModel) {
        ContentRetriever contentRetriever = Mockito.spy(termsOfUseRetriever(1));

        MultiUserAssistant assistant = AiServices.builder(MultiUserAssistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(contentRetriever)
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10))
                .build();

        assertMentionsCancellationPolicy(assistant.answer(1, CANCELLATION_QUESTION));

        // Memory id 1 is fresh, so the retriever is expected to see an empty conversation.
        Metadata expectedMetadata =
                Metadata.from(UserMessage.from(CANCELLATION_QUESTION), 1, Collections.emptyList());
        Mockito.verify(contentRetriever).retrieve(Query.from(CANCELLATION_QUESTION, expectedMetadata));
        Mockito.verifyNoMoreInteractions(contentRetriever);
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_query_transformer_and_content_retriever(ChatLanguageModel chatLanguageModel) {
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryTransformer(new ExpandingQueryTransformer(chatLanguageModel))
                        .contentRetriever(termsOfUseRetriever(1))
                        .build())
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_query_router_and_content_retriever(ChatLanguageModel chatLanguageModel) {
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(termsOfUseRetriever(1), "car rental company terms of use");
        retrieverToDescription.put(failingRetriever(), "articles about cats");

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryRouter(new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription))
                        .build())
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @Disabled("TODO fix")
    @ParameterizedTest
    @MethodSource("models")
    void should_not_route_when_query_is_ambiguous(ChatLanguageModel chatLanguageModel) {
        ContentRetriever catArticlesRetriever = Mockito.mock(ContentRetriever.class);
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(catArticlesRetriever, "articles about cats");

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryRouter(new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription))
                        .build())
                .build();

        Assertions.assertThat(assistant.answer(AMBIGUOUS_QUESTION)).isNotBlank();
        Mockito.verifyNoInteractions(catArticlesRetriever);
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_route_to_all_retrievers_when_query_is_ambiguous(ChatLanguageModel chatLanguageModel) {
        ContentRetriever contentRetriever = Mockito.spy(termsOfUseRetriever(1));
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(contentRetriever, "car rental company terms of use");

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryRouter(LanguageModelQueryRouter.builder()
                                .chatLanguageModel(chatLanguageModel)
                                .retrieverToDescription(retrieverToDescription)
                                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.ROUTE_TO_ALL)
                                .build())
                        .build())
                .build();

        Assertions.assertThat(assistant.answer(AMBIGUOUS_QUESTION)).isNotBlank();

        // No chat memory is configured, so the expected chat-memory list is null.
        Metadata expectedMetadata =
                Metadata.from(UserMessage.from(AMBIGUOUS_QUESTION), "default", (List<ChatMessage>) null);
        Mockito.verify(contentRetriever).retrieve(Query.from(AMBIGUOUS_QUESTION, expectedMetadata));
        Mockito.verifyNoMoreInteractions(contentRetriever);
    }

    @Disabled("Fixed in https://github.com/langchain4j/langchain4j/pull/2311")
    @ParameterizedTest
    @MethodSource("models")
    void should_fail_when_query_is_ambiguous(ChatLanguageModel chatLanguageModel) {
        ContentRetriever contentRetriever = Mockito.spy(termsOfUseRetriever(1));
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(contentRetriever, "car rental company terms of use");

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryRouter(LanguageModelQueryRouter.builder()
                                .chatLanguageModel(chatLanguageModel)
                                .retrieverToDescription(retrieverToDescription)
                                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.FAIL)
                                .build())
                        .build())
                .build();

        Assertions.assertThatThrownBy(() -> assistant.answer(AMBIGUOUS_QUESTION))
                .hasRootCauseExactlyInstanceOf(NumberFormatException.class);
        Mockito.verifyNoInteractions(contentRetriever);
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_content_retriever_and_content_aggregator(ChatLanguageModel chatLanguageModel) {
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .contentRetriever(termsOfUseRetriever(2))
                        .contentAggregator(new ReRankingContentAggregator(mockScoringModel()))
                        .build())
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_all_rag_components(ChatLanguageModel chatLanguageModel) {
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(termsOfUseRetriever(2), "car rental company terms of use");
        retrieverToDescription.put(failingRetriever(), "articles about unicorns");

        ReRankingContentAggregator contentAggregator = ReRankingContentAggregator.builder()
                .scoringModel(mockScoringModel())
                // Several queries may reach the aggregator after expansion; re-rank against the first one.
                .querySelector(queryToContents -> queryToContents.keySet().iterator().next())
                .build();

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .queryTransformer(new ExpandingQueryTransformer(chatLanguageModel))
                        .queryRouter(new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription))
                        .contentAggregator(contentAggregator)
                        .build())
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_dynamicFilter_by_user_id(ChatLanguageModel chatLanguageModel) {
        TextSegment user1Preference = TextSegment.from(
                "My favorite color is green", dev.langchain4j.data.document.Metadata.metadata("userId", "1"));
        TextSegment user2Preference = TextSegment.from(
                "My favorite color is red", dev.langchain4j.data.document.Metadata.metadata("userId", "2"));

        // Only segments whose "userId" equals the chat memory id of the current query are eligible.
        Function<Query, Filter> filterByUserId = query ->
                MetadataFilterBuilder.metadataKey("userId").isEqualTo(query.metadata().chatMemoryId().toString());

        PersonalizedAssistant assistant = AiServices.builder(PersonalizedAssistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(EmbeddingStoreContentRetriever.builder()
                        .embeddingStore(newStoreContaining(user1Preference, user2Preference))
                        .embeddingModel(embeddingModel)
                        .dynamicFilter(filterByUserId)
                        .build())
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10))
                .build();

        Assertions.assertThat(assistant.chat("1", "Which color would be best for a dress?")).containsIgnoringCase("green");
        Assertions.assertThat(assistant.chat("2", "Which color would be best for a dress?")).containsIgnoringCase("red");
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_static_metadata_filter(ChatLanguageModel chatLanguageModel) {
        TextSegment cats = TextSegment.from("cats", dev.langchain4j.data.document.Metadata.metadata("animal", "cat"));
        TextSegment dogs = TextSegment.from("dogs", dev.langchain4j.data.document.Metadata.metadata("animal", "dog"));
        Filter onlyDogs = MetadataFilterBuilder.metadataKey("animal").isEqualTo("dog");

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(EmbeddingStoreContentRetriever.builder()
                        .embeddingStore(newStoreContaining(cats, dogs))
                        .embeddingModel(embeddingModel)
                        .filter(onlyDogs)
                        .build())
                .build();

        Assertions.assertThat(assistant.answer("Which animal is mentioned?")).containsIgnoringCase("dog");
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_legacy_retriever(ChatLanguageModel chatLanguageModel) {
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retriever(EmbeddingStoreRetriever.from(embeddingStore, embeddingModel, 1))
                .build();

        assertMentionsCancellationPolicy(assistant.answer(CANCELLATION_QUESTION));
    }

    @ParameterizedTest
    @MethodSource("models")
    void should_use_content_retriever_and_return_sources_inside_result(ChatLanguageModel chatLanguageModel) {
        AssistantReturningResult assistant = AiServices.builder(AssistantReturningResult.class)
                .chatLanguageModel(chatLanguageModel)
                .contentRetriever(termsOfUseRetriever(1))
                .build();

        Result<String> result = assistant.answer(CANCELLATION_QUESTION);

        assertMentionsCancellationPolicy(result.content());
        Assertions.assertThat(result.tokenUsage()).isNotNull();
        Assertions.assertThat(result.sources()).hasSize(1);

        Content source = result.sources().get(0);
        Assertions.assertThat(source.textSegment().text()).isEqualToIgnoringWhitespace(
                "4. Cancellation Policy4.1 Reservations can be cancelled up to 61 days prior to the start of the booking period.4.2 If the booking period is less than 17 days, cancellations are not permitted.");
        Assertions.assertThat(source.textSegment().metadata("index")).isEqualTo("3");
        Assertions.assertThat(source.textSegment().metadata("file_name")).isEqualTo("miles-of-smiles-terms-of-use.txt");
    }

    /** Asserts that {@code answer} mentions at least one number from the cancellation policy. */
    private static void assertMentionsCancellationPolicy(String answer) {
        Assertions.assertThat(answer).containsAnyOf(ALLOWED_CANCELLATION_PERIOD_DAYS, MIN_BOOKING_PERIOD_DAYS);
    }

    /** Builds a retriever over {@link #embeddingStore} that returns at most {@code maxResults} segments. */
    private EmbeddingStoreContentRetriever termsOfUseRetriever(int maxResults) {
        return EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(maxResults)
                .build();
    }

    /** Returns a retriever that throws if it is ever invoked; used as a routing decoy. */
    private static ContentRetriever failingRetriever() {
        return query -> {
            throw new RuntimeException("Should never be called");
        };
    }

    /** Returns a mocked scoring model that answers every {@code scoreAll} call with scores 0.9 and 0.7. */
    private static ScoringModel mockScoringModel() {
        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        Mockito.when(scoringModel.scoreAll(Mockito.any(), Mockito.any()))
                .thenReturn(Response.from(Arrays.asList(0.9, 0.7)));
        return scoringModel;
    }

    /** Creates a fresh in-memory store holding {@code segments}, each embedded with {@link #embeddingModel}. */
    private EmbeddingStore<TextSegment> newStoreContaining(TextSegment... segments) {
        EmbeddingStore<TextSegment> store = new InMemoryEmbeddingStore<>();
        for (TextSegment segment : segments) {
            Embedding embedding = embeddingModel.embed(segment).content();
            store.add(embedding, segment);
        }
        return store;
    }

    /**
     * Loads the classpath resource {@code fileName}, splits it recursively (100 tokens per segment,
     * no overlap, counted by {@link OpenAiTokenizer}), embeds the segments and stores them.
     */
    private void ingest(String fileName, EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {
        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(100, 0, new OpenAiTokenizer()))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();
        ingestor.ingest(FileSystemDocumentLoader.loadDocument(toPath(fileName), new TextDocumentParser()));
    }

    /** Chat models every parameterized test is run against; configured from {@code OPENAI_*} env vars. */
    static Stream<Arguments> models() {
        ChatLanguageModel openAiModel = OpenAiChatModel.builder()
                .baseUrl(System.getenv("OPENAI_BASE_URL"))
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
                .modelName(OpenAiChatModelName.GPT_4_O_MINI)
                .logRequests(true)
                .logResponses(true)
                .build();
        return Stream.of(Arguments.of(openAiModel));
    }

    /**
     * Resolves a classpath resource to a file-system path.
     *
     * @throws IllegalStateException if the resource is not on the classpath
     */
    private Path toPath(String fileName) {
        URL resource = getClass().getClassLoader().getResource(fileName);
        if (resource == null) {
            // Without this check a missing resource surfaced as a bare NullPointerException.
            throw new IllegalStateException("Test resource not found on classpath: " + fileName);
        }
        try {
            return Paths.get(resource.toURI());
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }
}
