package dev.langchain4j.chain;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.retriever.Retriever;
import java.util.Arrays;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;

/**
 * Tests for {@link ConversationalRetrievalChain}.
 *
 * <p>Each test stubs a retriever with two text pieces ("Segment 1", "Segment 2") and runs the chain
 * with {@link #QUERY}. It then checks three things: the returned answer, the single user message
 * sent to the mocked model (with the retrieved content injected), and the resulting contents of
 * the chat memory.
 *
 * <p>The {@code test_backward_compatibility_*} cases exercise the legacy {@link Retriever} /
 * {@code promptTemplate} builder options. The other cases use the {@link ContentRetriever} /
 * {@link DefaultRetrievalAugmentor} API.
 */
@ExtendWith(MockitoExtension.class)
class ConversationalRetrievalChainTest {

    private static final String QUERY = "query";
    private static final String ANSWER = "answer";

    @Mock
    ChatLanguageModel chatLanguageModel;

    @Mock
    ContentRetriever contentRetriever;

    @Mock
    Retriever<TextSegment> retriever;

    // Spy on a real memory so the test can inspect which messages the chain stored.
    @Spy
    ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);

    @Captor
    ArgumentCaptor<List<ChatMessage>> messagesCaptor;

    @BeforeEach
    void beforeEach() {
        // Every test expects the model to reply with ANSWER regardless of the messages sent.
        Mockito.when(chatLanguageModel.generate(Mockito.anyList()))
                .thenReturn(Response.from(AiMessage.aiMessage(ANSWER)));
    }

    @Test
    void should_inject_retrieved_segments() {
        Mockito.when(contentRetriever.retrieve(Mockito.any()))
                .thenReturn(Arrays.asList(Content.from("Segment 1"), Content.from("Segment 2")));

        Assertions.assertThat(ConversationalRetrievalChain.builder()
                        .chatLanguageModel(chatLanguageModel)
                        .chatMemory(chatMemory)
                        .contentRetriever(contentRetriever)
                        .build()
                        .execute(QUERY))
                .isEqualTo(ANSWER);

        assertModelCalledWithAndMemoryUpdated(
                UserMessage.from("query\n\nAnswer using the following information:\nSegment 1\n\nSegment 2"));
    }

    @Test
    void should_inject_retrieved_segments_using_custom_prompt_template() {
        Mockito.when(contentRetriever.retrieve(Mockito.any()))
                .thenReturn(Arrays.asList(Content.from("Segment 1"), Content.from("Segment 2")));

        Assertions.assertThat(ConversationalRetrievalChain.builder()
                        .chatLanguageModel(chatLanguageModel)
                        .chatMemory(chatMemory)
                        .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                                .contentRetriever(contentRetriever)
                                .contentInjector(DefaultContentInjector.builder()
                                        .promptTemplate(PromptTemplate.from(
                                                "Answer '{{userMessage}}' using '{{contents}}'"))
                                        .build())
                                .build())
                        .build()
                        .execute(QUERY))
                .isEqualTo(ANSWER);

        assertModelCalledWithAndMemoryUpdated(
                UserMessage.from("Answer 'query' using 'Segment 1\n\nSegment 2'"));
    }

    @Test
    void test_backward_compatibility_should_inject_retrieved_segments() {
        Mockito.when(retriever.findRelevant(QUERY))
                .thenReturn(Arrays.asList(TextSegment.from("Segment 1"), TextSegment.from("Segment 2")));
        // Use the interface's real adapter so the chain sees a ContentRetriever backed by the mock.
        Mockito.when(retriever.toContentRetriever()).thenCallRealMethod();

        Assertions.assertThat(ConversationalRetrievalChain.builder()
                        .chatLanguageModel(chatLanguageModel)
                        .chatMemory(chatMemory)
                        .retriever(retriever)
                        .build()
                        .execute(QUERY))
                .isEqualTo(ANSWER);

        assertModelCalledWithAndMemoryUpdated(UserMessage.from(
                "Answer the following question to the best of your ability: query\n\nBase your answer on the following information:\nSegment 1\n\nSegment 2"));
    }

    @Test
    void test_backward_compatibility_should_inject_retrieved_segments_using_custom_prompt_template() {
        Mockito.when(retriever.findRelevant(QUERY))
                .thenReturn(Arrays.asList(TextSegment.from("Segment 1"), TextSegment.from("Segment 2")));
        // Use the interface's real adapter so the chain sees a ContentRetriever backed by the mock.
        Mockito.when(retriever.toContentRetriever()).thenCallRealMethod();

        Assertions.assertThat(ConversationalRetrievalChain.builder()
                        .chatLanguageModel(chatLanguageModel)
                        .chatMemory(chatMemory)
                        .promptTemplate(PromptTemplate.from("Answer '{{question}}' using '{{information}}'"))
                        .retriever(retriever)
                        .build()
                        .execute(QUERY))
                .isEqualTo(ANSWER);

        assertModelCalledWithAndMemoryUpdated(
                UserMessage.from("Answer 'query' using 'Segment 1\n\nSegment 2'"));
    }

    /**
     * Verifies the model call and the resulting chat memory.
     *
     * <p>Checks that the model was invoked exactly once with a message list consisting solely of
     * {@code expectedUserMessage}. Also checks that the chat memory now holds that message followed
     * by the stubbed {@link #ANSWER}.
     *
     * @param expectedUserMessage the user message (with retrieved content injected) the chain should
     *     have sent and stored
     */
    private void assertModelCalledWithAndMemoryUpdated(ChatMessage expectedUserMessage) {
        Mockito.verify(chatLanguageModel).generate(messagesCaptor.capture());
        Assertions.assertThat(messagesCaptor.getValue()).containsExactly(expectedUserMessage);
        Assertions.assertThat(chatMemory.messages())
                .containsExactly(expectedUserMessage, AiMessage.from(ANSWER));
    }
}
