package dev.langchain4j.store.embedding.filter.builder.sql;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.parser.sql.SqlFilterParser;
import java.util.HashMap;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a {@link Filter} from a natural-language {@link Query} with the help of a
 * {@link ChatLanguageModel}.
 *
 * <p>The model is prompted with a {@code CREATE TABLE} statement rendered from the supplied
 * {@link TableDefinition} and asked to answer the query with a SQL statement. The model's response
 * is trimmed and handed to a {@link SqlFilterParser}. If that fails, a best-effort attempt is made
 * to extract a {@code SELECT ... WHERE ...} statement from the raw response and parse that instead.
 * When nothing parseable can be obtained, {@link #build(Query)} returns {@code null}.
 */
@Experimental
public class LanguageModelSqlFilterBuilder {

    private static final Logger log = LoggerFactory.getLogger(LanguageModelSqlFilterBuilder.class);

    /**
     * Prompt used when no custom template is supplied. It expects the variables {@code query} and
     * {@code create_table_statement} and ends with an opening {@code ```sql} fence.
     */
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            "### Instructions:\n"
                    + "Your task is to convert a question into a SQL query, given a Postgres database schema.\n"
                    + "Adhere to these rules:\n"
                    + "- **Deliberately go through the question and database schema word by word** to appropriately answer the question\n"
                    + "- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.\n"
                    + "- When creating a ratio, always cast the numerator as float\n"
                    + "\n"
                    + "### Input:\n"
                    + "Generate a SQL query that answers the question `{{query}}`.\n"
                    + "This query will run on a database whose schema is represented in this string:\n"
                    + "{{create_table_statement}}\n"
                    + "\n"
                    + "### Response:\n"
                    + "Based on your instructions, here is the SQL query I have generated to answer the question `{{query}}`:\n"
                    + "```sql");

    /** Model that is asked to produce the SQL statement. */
    protected final ChatLanguageModel chatLanguageModel;

    /** Table description the prompt is based on. */
    protected final TableDefinition tableDefinition;

    /** {@code CREATE TABLE} text rendered once from {@link #tableDefinition} at construction time. */
    protected final String createTableStatement;

    /** Template used to create the prompt sent to the model. */
    protected final PromptTemplate promptTemplate;

    /** Parser that turns SQL text into a {@link Filter}. */
    protected final SqlFilterParser sqlFilterParser;

    /**
     * Fluent builder for {@link LanguageModelSqlFilterBuilder}.
     *
     * <p>{@code chatLanguageModel} and {@code tableDefinition} are required (validated when
     * {@link #build()} invokes the constructor); {@code promptTemplate} and {@code sqlFilterParser}
     * fall back to defaults when left {@code null}.
     */
    @Generated
    public static class LanguageModelSqlFilterBuilderBuilder {

        @Generated
        private ChatLanguageModel chatLanguageModel;

        @Generated
        private TableDefinition tableDefinition;

        @Generated
        private PromptTemplate promptTemplate;

        @Generated
        private SqlFilterParser sqlFilterParser;

        @Generated
        LanguageModelSqlFilterBuilderBuilder() {
        }

        /**
         * Sets the model that will generate the SQL statement.
         *
         * @param chatLanguageModel the model to use
         * @return this builder
         */
        @Generated
        public LanguageModelSqlFilterBuilderBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /**
         * Sets the table definition that is rendered into the prompt.
         *
         * @param tableDefinition the table definition to use
         * @return this builder
         */
        @Generated
        public LanguageModelSqlFilterBuilderBuilder tableDefinition(TableDefinition tableDefinition) {
            this.tableDefinition = tableDefinition;
            return this;
        }

        /**
         * Sets a custom prompt template.
         *
         * @param promptTemplate the template to use, or {@code null} for the default template
         * @return this builder
         */
        @Generated
        public LanguageModelSqlFilterBuilderBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        /**
         * Sets a custom SQL filter parser.
         *
         * @param sqlFilterParser the parser to use, or {@code null} for a new {@link SqlFilterParser}
         * @return this builder
         */
        @Generated
        public LanguageModelSqlFilterBuilderBuilder sqlFilterParser(SqlFilterParser sqlFilterParser) {
            this.sqlFilterParser = sqlFilterParser;
            return this;
        }

        /**
         * Creates the configured {@link LanguageModelSqlFilterBuilder}.
         *
         * @return a new instance built from the values set on this builder
         */
        @Generated
        public LanguageModelSqlFilterBuilder build() {
            return new LanguageModelSqlFilterBuilder(
                    this.chatLanguageModel, this.tableDefinition, this.promptTemplate, this.sqlFilterParser);
        }

        @Generated
        public String toString() {
            return "LanguageModelSqlFilterBuilder.LanguageModelSqlFilterBuilderBuilder(chatLanguageModel="
                    + this.chatLanguageModel
                    + ", tableDefinition=" + this.tableDefinition
                    + ", promptTemplate=" + this.promptTemplate
                    + ", sqlFilterParser=" + this.sqlFilterParser
                    + ")";
        }
    }

    /**
     * Creates an instance that uses the default prompt template and a new {@link SqlFilterParser}.
     *
     * @param chatLanguageModel the model that generates the SQL statement; must not be {@code null}
     * @param tableDefinition   the table definition rendered into the prompt; must not be {@code null}
     */
    public LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel, TableDefinition tableDefinition) {
        this(chatLanguageModel, tableDefinition, DEFAULT_PROMPT_TEMPLATE, new SqlFilterParser());
    }

    /**
     * Full constructor used by the builder.
     *
     * @param chatLanguageModel required model
     * @param tableDefinition   required table definition
     * @param promptTemplate    optional template; the default template is used when {@code null}
     * @param sqlFilterParser   optional parser; a new {@link SqlFilterParser} is used when {@code null}
     */
    private LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel,
                                          TableDefinition tableDefinition,
                                          PromptTemplate promptTemplate,
                                          SqlFilterParser sqlFilterParser) {
        this.chatLanguageModel = ValidationUtils.ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.tableDefinition = ValidationUtils.ensureNotNull(tableDefinition, "tableDefinition");
        // NOTE: format(...) is overridable and runs during construction, so an override
        // must not depend on state initialised by a subclass constructor.
        this.createTableStatement = format(tableDefinition);
        this.promptTemplate = Utils.getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        this.sqlFilterParser = Utils.getOrDefault(sqlFilterParser, SqlFilterParser::new);
    }

    /**
     * Asks the model for a SQL statement answering {@code query} and parses it into a {@link Filter}.
     *
     * <p>The response is first passed through {@link #clean(String)} and parsed. If parsing throws,
     * {@link #fallback(Query, String, String, Exception)} decides the result.
     *
     * @param query the query to build a filter for
     * @return the parsed filter, or {@code null} if the model returned no text or no parseable SQL
     *         could be obtained
     */
    public Filter build(Query query) {
        Prompt prompt = createPrompt(query);
        String text = this.chatLanguageModel
                .chat(new ChatMessage[] {prompt.toUserMessage()})
                .aiMessage()
                .text();
        if (text == null) {
            // Without any text there is nothing to clean or parse; give up the same way
            // fallback(...) does instead of failing with a NullPointerException.
            log.warn("Language model returned no text, cannot build a filter");
            return null;
        }
        String clean = clean(text);
        log.trace("Cleaned SQL: '{}'", clean);
        try {
            return this.sqlFilterParser.parse(clean);
        } catch (Exception e) {
            log.warn("Failed parsing the following SQL: '{}'", clean, e);
            return fallback(query, text, clean, e);
        }
    }

    /**
     * Creates the prompt sent to the model by applying the template to the variables
     * {@code create_table_statement} and {@code query}.
     *
     * @param query the query whose text is inserted into the prompt
     * @return the rendered prompt
     */
    protected Prompt createPrompt(Query query) {
        HashMap<String, Object> variables = new HashMap<>();
        variables.put("create_table_statement", this.createTableStatement);
        variables.put("query", query.text());
        return this.promptTemplate.apply(variables);
    }

    /**
     * Normalises the raw model response before the first parsing attempt.
     * The default implementation only trims surrounding whitespace.
     *
     * @param str the raw model response; must not be {@code null}
     * @return the cleaned text
     */
    protected String clean(String str) {
        return str.trim();
    }

    /**
     * Invoked when parsing the cleaned response failed. The default implementation tries to extract
     * a {@code SELECT} statement from the raw response and parse that.
     *
     * @param query the original query (unused by the default implementation)
     * @param str   the raw model response
     * @param str2  the cleaned response that failed to parse (unused by the default implementation)
     * @param exc   the exception thrown by the first parsing attempt (unused by the default
     *              implementation)
     * @return the parsed filter, or {@code null} if no statement could be extracted or parsed
     */
    protected Filter fallback(Query query, String str, String str2, Exception exc) {
        String extractSelectStatement = extractSelectStatement(str);
        if (Utils.isNullOrBlank(extractSelectStatement)) {
            log.trace("Cannot extract SQL, giving up");
            return null;
        }
        try {
            log.trace("Extracted SQL: '{}'", extractSelectStatement);
            return this.sqlFilterParser.parse(extractSelectStatement);
        } catch (Exception e) {
            log.warn("Failed parsing the following SQL, giving up: '{}'", extractSelectStatement, e);
            return null;
        }
    }

    /**
     * Heuristically extracts a {@code SELECT ... WHERE ...} statement from free-form model output.
     *
     * <ol>
     *   <li>If the text contains a {@code ```sql} fence, it is split on that fence and the first
     *       piece mentioning both SELECT and WHERE (case-insensitively) is returned, cut at the
     *       next {@code ```}.</li>
     *   <li>Otherwise, if the text contains a plain {@code ```} fence, the same is done with pieces
     *       split on {@code ```}.</li>
     *   <li>Otherwise the text is split on the upper-case keyword {@code SELECT}. The first piece
     *       following a keyword that mentions WHERE is returned with {@code "SELECT "} prepended;
     *       if that piece spans several lines, only its first line mentioning WHERE is used.</li>
     * </ol>
     *
     * @param str the raw model response; may be {@code null}
     * @return the extracted statement, or {@code null} if none was found
     */
    protected String extractSelectStatement(String str) {
        if (str == null) {
            return null;
        }
        if (str.contains("```sql")) {
            return firstFencedStatement(str.split("```sql"));
        }
        if (str.contains("```")) {
            return firstFencedStatement(str.split("```"));
        }

        String[] segments = str.split("SELECT");
        // String.split keeps the (possibly empty) text in front of the first match as segment 0.
        // When the keyword is present that segment is prose preceding the statement, so it must
        // not be mistaken for a statement body just because it happens to mention "where".
        // When the keyword is absent the single segment is the whole text and is still examined.
        int firstCandidate = str.contains("SELECT") ? 1 : 0;
        for (int i = firstCandidate; i < segments.length; i++) {
            String segment = segments[i];
            if (!segment.toUpperCase().contains("WHERE")) {
                continue;
            }
            if (!segment.contains("\n")) {
                return "SELECT " + segment.trim();
            }
            for (String line : segment.split("\n")) {
                if (line.toUpperCase().contains("WHERE")) {
                    return "SELECT " + line.trim();
                }
            }
        }
        return null;
    }

    /**
     * Returns the first piece that mentions both SELECT and WHERE (case-insensitively), truncated
     * at the next {@code ```} fence and trimmed, or {@code null} if no piece qualifies.
     */
    private static String firstFencedStatement(String[] pieces) {
        for (String piece : pieces) {
            String upperCased = piece.toUpperCase();
            if (upperCased.contains("SELECT") && upperCased.contains("WHERE")) {
                return piece.split("```")[0].trim();
            }
        }
        return null;
    }

    /**
     * Renders a table definition as the {@code CREATE TABLE} text placed into the prompt.
     *
     * <p>Every column is written on its own line as {@code name type,} (the comma is emitted for
     * every column, including the last), followed by {@code  -- description} when the column has a
     * non-blank description. A non-blank table description is appended as {@code COMMENT='...'}.
     *
     * @param tableDefinition the table to render
     * @return the rendered statement
     */
    protected String format(TableDefinition tableDefinition) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("CREATE TABLE %s (\n", tableDefinition.name()));
        for (ColumnDefinition columnDefinition : tableDefinition.columns()) {
            sb.append(String.format("    %s %s,", columnDefinition.name(), columnDefinition.type()));
            if (!Utils.isNullOrBlank(columnDefinition.description())) {
                sb.append(String.format(" -- %s", columnDefinition.description()));
            }
            sb.append("\n");
        }
        sb.append(")");
        if (!Utils.isNullOrBlank(tableDefinition.description())) {
            sb.append(String.format(" COMMENT='%s'", tableDefinition.description()));
        }
        sb.append(";");
        return sb.toString();
    }

    /**
     * Creates a new fluent builder.
     *
     * @return an empty {@link LanguageModelSqlFilterBuilderBuilder}
     */
    @Generated
    public static LanguageModelSqlFilterBuilderBuilder builder() {
        return new LanguageModelSqlFilterBuilderBuilder();
    }
}
