package com.alibaba.cloud.ai.vectorstore.oceanbase;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/alibaba/cloud/ai/vectorstore/oceanbase/OceanBaseVectorStore.class */
public class OceanBaseVectorStore extends AbstractObservationVectorStore implements InitializingBean {
    private static final String DATA_BASE_SYSTEM = "oceanbase";
    private static final String REF_DOC_NAME = "refDocId";
    private static final String METADATA_FIELD_NAME = "metadata";
    private static final String CONTENT_FIELD_NAME = "content";
    private static final String DOC_NAME = "docId";
    private static final String CREATE_TABLE_SQL_TEMPLATE = "CREATE TABLE IF NOT EXISTS %s (id varchar(100) PRIMARY KEY, vector VECTOR(384) NOT NULL, description text, metadata text)";
    private static final String INSERT_DOC_SQL_TEMPLATE = "INSERT INTO %s (id, vector, description, metadata) VALUES (?, ?, ?, ?)";
    private static final String DELETE_DOC_SQL_TEMPLATE = "DELETE FROM %s WHERE id = ?";
    private static final String DELETE_DOC_BY_FILTER_SQL_TEMPLATE = "DELETE FROM %s WHERE %s";
    private static final String SIMILARITY_SEARCH_SQL_TEMPLATE = "SELECT id, vector, description, metadata, l2_distance(vector,?) as distance FROM %s ORDER BY vector_distance(vector, ?) ASC LIMIT ?";
    public final FilterExpressionConverter filterExpressionConverter;
    private final String tableName;
    private final Integer defaultTopK;
    private final Double defaultSimilarityThreshold;
    private final DataSource dataSource;
    private final ObjectMapper objectMapper;
    private static final Logger logger = LoggerFactory.getLogger(OceanBaseVectorStore.class);
    private static final Double DEFAULT_SIMILARITY_THRESHOLD = Double.valueOf(0.0d);

    /* loaded from: input_file:com/alibaba/cloud/ai/vectorstore/oceanbase/OceanBaseVectorStore$Builder.class */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {
        private final String tableName;
        private final DataSource dataSource;
        private int defaultTopK;
        private Double defaultSimilarityThreshold;

        private Builder(String str, DataSource dataSource, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            this.defaultTopK = 4;
            this.defaultSimilarityThreshold = OceanBaseVectorStore.DEFAULT_SIMILARITY_THRESHOLD;
            Assert.notNull(str, "Table name must not be null");
            Assert.notNull(dataSource, "Data source must not be null");
            this.tableName = str.toLowerCase();
            this.dataSource = dataSource;
        }

        public Builder defaultTopK(int i) {
            Assert.isTrue(i >= 0, "The topK should be positive value.");
            this.defaultTopK = i;
            return this;
        }

        public Builder defaultSimilarityThreshold(Double d) {
            Assert.isTrue(d.doubleValue() >= 0.0d && d.doubleValue() <= 1.0d, "The similarity threshold must be in range [0.0:1.0].");
            this.defaultSimilarityThreshold = d;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public OceanBaseVectorStore m2build() {
            try {
                return new OceanBaseVectorStore(this);
            } catch (Exception e) {
                throw new RuntimeException("Failed to build OceanBaseVectorStore: " + e.getMessage(), e);
            }
        }
    }

    protected OceanBaseVectorStore(Builder builder) {
        super(builder);
        this.filterExpressionConverter = new OceanBaseVectorFilterExpressionConverter();
        this.tableName = builder.tableName;
        this.dataSource = builder.dataSource;
        this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
        this.defaultSimilarityThreshold = builder.defaultSimilarityThreshold;
        this.defaultTopK = Integer.valueOf(builder.defaultTopK);
    }

    public static Builder builder(String str, DataSource dataSource, EmbeddingModel embeddingModel) {
        return new Builder(str, dataSource, embeddingModel);
    }

    public void afterPropertiesSet() {
        initializeDatabase();
    }

    private void initializeDatabase() {
        executeUpdate(String.format(CREATE_TABLE_SQL_TEMPLATE, this.tableName));
        logger.debug("Successfully created or verified table: {}", this.tableName);
    }

    public void doAdd(List<Document> list) {
        Assert.notNull(list, "The document list should not be null.");
        if (CollectionUtils.isEmpty(list)) {
            return;
        }
        List embed = this.embeddingModel.embed(list, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        String format = String.format(INSERT_DOC_SQL_TEMPLATE, this.tableName);
        try {
            Connection connection = this.dataSource.getConnection();
            try {
                PreparedStatement prepareStatement = connection.prepareStatement(format);
                for (int i = 0; i < list.size(); i++) {
                    try {
                        Document document = list.get(i);
                        Map<String, String> createMetadata = createMetadata(document);
                        String convertEmbeddingToString = convertEmbeddingToString((float[]) embed.get(i));
                        prepareStatement.setString(1, document.getId());
                        prepareStatement.setString(2, convertEmbeddingToString);
                        prepareStatement.setString(3, document.getText());
                        prepareStatement.setString(4, this.objectMapper.writeValueAsString(createMetadata));
                        prepareStatement.addBatch();
                    } catch (Throwable th) {
                        if (prepareStatement != null) {
                            try {
                                prepareStatement.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        }
                        throw th;
                    }
                }
                prepareStatement.executeBatch();
                if (prepareStatement != null) {
                    prepareStatement.close();
                }
                if (connection != null) {
                    connection.close();
                }
            } finally {
            }
        } catch (Exception e) {
            logger.error("Failed to add documents", e);
            throw new RuntimeException("Failed to add documents to OceanBase", e);
        }
    }

    private Map<String, String> createMetadata(Document document) throws JsonProcessingException {
        HashMap hashMap = new HashMap();
        hashMap.put(REF_DOC_NAME, (String) Optional.ofNullable(document.getMetadata().get(DOC_NAME)).map((v0) -> {
            return v0.toString();
        }).orElse(document.getId()));
        hashMap.put(CONTENT_FIELD_NAME, document.getText());
        hashMap.put(METADATA_FIELD_NAME, this.objectMapper.writeValueAsString(document.getMetadata()));
        return hashMap;
    }

    private String convertEmbeddingToString(float[] fArr) {
        return Arrays.toString(IntStream.range(0, fArr.length).mapToObj(i -> {
            return Float.valueOf(fArr[i]);
        }).toArray());
    }

    public void doDelete(List<String> list) {
        if (CollectionUtils.isEmpty(list)) {
            return;
        }
        executeBatchUpdate(String.format(DELETE_DOC_SQL_TEMPLATE, this.tableName), list);
    }

    public void doDelete(Filter.Expression expression) {
        executeUpdate(String.format(DELETE_DOC_BY_FILTER_SQL_TEMPLATE, this.tableName, this.filterExpressionConverter.convertExpression(expression)));
    }

    public List<Document> similaritySearch(String str) {
        return similaritySearch(SearchRequest.builder().query(str).topK(this.defaultTopK.intValue()).similarityThreshold(this.defaultSimilarityThreshold.doubleValue()).build());
    }

    public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
        String format = String.format(SIMILARITY_SEARCH_SQL_TEMPLATE, this.tableName);
        ArrayList arrayList = new ArrayList();
        try {
            Connection connection = this.dataSource.getConnection();
            try {
                PreparedStatement prepareStatement = connection.prepareStatement(format);
                try {
                    String convertQueryToVectorBytes = convertQueryToVectorBytes(searchRequest.getQuery());
                    prepareStatement.setString(1, convertQueryToVectorBytes);
                    prepareStatement.setString(2, convertQueryToVectorBytes);
                    prepareStatement.setInt(3, searchRequest.getTopK());
                    ResultSet executeQuery = prepareStatement.executeQuery();
                    while (executeQuery.next()) {
                        arrayList.add(extractDocumentFromResultSet(executeQuery));
                    }
                    if (prepareStatement != null) {
                        prepareStatement.close();
                    }
                    if (connection != null) {
                        connection.close();
                    }
                    return arrayList;
                } catch (Throwable th) {
                    if (prepareStatement != null) {
                        try {
                            prepareStatement.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    }
                    throw th;
                }
            } finally {
            }
        } catch (Exception e) {
            logger.error("Failed to perform similarity search", e);
            throw new RuntimeException("Failed to perform similarity search in OceanBase", e);
        }
    }

    private Document extractDocumentFromResultSet(ResultSet resultSet) throws SQLException, JsonProcessingException {
        String string = resultSet.getString("id");
        String string2 = resultSet.getString(METADATA_FIELD_NAME);
        String string3 = resultSet.getString("distance");
        Map<String, String> extractMetadata = extractMetadata(string2);
        String str = extractMetadata.get(CONTENT_FIELD_NAME);
        Map map = (Map) this.objectMapper.readValue(extractMetadata.get(METADATA_FIELD_NAME), new TypeReference<Map<String, Object>>() { // from class: com.alibaba.cloud.ai.vectorstore.oceanbase.OceanBaseVectorStore.1
        });
        map.put("distance", string3);
        return new Document(String.valueOf(string), str, map);
    }

    private Map<String, String> extractMetadata(String str) throws JsonProcessingException {
        return (Map) this.objectMapper.readValue(str, Map.class);
    }

    private String convertQueryToVectorBytes(String str) {
        return Arrays.toString(this.embeddingModel.embed(str));
    }

    private void executeUpdate(String str) {
        try {
            Connection connection = this.dataSource.getConnection();
            try {
                PreparedStatement prepareStatement = connection.prepareStatement(str);
                try {
                    prepareStatement.execute();
                    if (prepareStatement != null) {
                        prepareStatement.close();
                    }
                    if (connection != null) {
                        connection.close();
                    }
                } catch (Throwable th) {
                    if (prepareStatement != null) {
                        try {
                            prepareStatement.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    }
                    throw th;
                }
            } finally {
            }
        } catch (SQLException e) {
            logger.error("SQL execution failed", e);
            throw new RuntimeException("Failed to execute SQL", e);
        }
    }

    private void executeBatchUpdate(String str, List<String> list) {
        try {
            Connection connection = this.dataSource.getConnection();
            try {
                PreparedStatement prepareStatement = connection.prepareStatement(str);
                try {
                    Iterator<String> it = list.iterator();
                    while (it.hasNext()) {
                        prepareStatement.setLong(1, Long.parseLong(it.next()));
                        prepareStatement.addBatch();
                    }
                    prepareStatement.executeBatch();
                    if (prepareStatement != null) {
                        prepareStatement.close();
                    }
                    if (connection != null) {
                        connection.close();
                    }
                } catch (Throwable th) {
                    if (prepareStatement != null) {
                        try {
                            prepareStatement.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    }
                    throw th;
                }
            } finally {
            }
        } catch (SQLException e) {
            logger.error("Batch SQL execution failed", e);
            throw new RuntimeException("Failed to execute batch SQL", e);
        }
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String str) {
        return VectorStoreObservationContext.builder(DATA_BASE_SYSTEM, str).collectionName(this.tableName).dimensions(Integer.valueOf(this.embeddingModel.dimensions()));
    }
}
