package dev.langchain4j.store.embedding.vespa;

import ai.vespa.client.dsl.A;
import ai.vespa.client.dsl.NearestNeighbor;
import ai.vespa.client.dsl.Q;
import ai.vespa.feed.client.DocumentId;
import ai.vespa.feed.client.FeedClientBuilder;
import ai.vespa.feed.client.FeedException;
import ai.vespa.feed.client.JsonFeeder;
import ai.vespa.feed.client.Result;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.vespa.Record;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import okhttp3.ResponseBody;
import retrofit2.Response;

/* loaded from: input_file:dev/langchain4j/store/embedding/vespa/VespaEmbeddingStore.class */
/**
 * {@link EmbeddingStore} of {@link TextSegment}s backed by a Vespa application.
 *
 * <p>Documents are written through a {@link JsonFeeder} created per call and queried through a
 * lazily created {@code VespaApi} client. Each stored document carries the embedding in the
 * {@code vector} field and, when available, the segment text in the {@code text_segment} field.
 *
 * <p>The lazily created API client is cached in a plain, unsynchronized field.
 */
public class VespaEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
    static final String DEFAULT_NAMESPACE = "namespace";
    static final String DEFAULT_DOCUMENT_TYPE = "langchain4j";
    private static final String DEFAULT_CLUSTER_NAME = "langchain4j";
    private static final boolean DEFAULT_AVOID_DUPS = true;
    private static final String FIELD_NAME_TEXT_SEGMENT = "text_segment";
    private static final String FIELD_NAME_VECTOR = "vector";
    private static final String FIELD_NAME_DOCUMENT_ID = "documentid";
    private static final String DEFAULT_RANK_PROFILE = "langchain4j_relevance_score";
    private static final int DEFAULT_TARGET_HITS = 10;

    private final String url;
    private final Path keyPath;
    private final Path certPath;
    private final Duration timeout;
    private final String namespace;
    private final String documentType;
    private final String clusterName;
    private final String rankProfile;
    private final int targetHits;
    private final boolean avoidDups;
    private final boolean logRequests;
    private final boolean logResponses;
    // Created on first use by api(); not synchronized.
    private VespaApi api;

    /**
     * Fluent builder for {@link VespaEmbeddingStore}. Every option except {@code url} may be left
     * unset, in which case the store falls back to its built-in default.
     */
    public static class Builder {
        private String url;
        private String keyPath;
        private String certPath;
        private Duration timeout;
        private String namespace;
        private String documentType;
        private String clusterName;
        private String rankProfile;
        private Integer targetHits;
        private Boolean avoidDups;
        private Boolean logRequests;
        private Boolean logResponses;

        /** Sets the Vespa endpoint URL (required). */
        public Builder url(String str) {
            this.url = str;
            return this;
        }

        /** Sets the path of the key file handed to the feed and query clients. */
        public Builder keyPath(String str) {
            this.keyPath = str;
            return this;
        }

        /** Sets the path of the certificate file handed to the feed and query clients. */
        public Builder certPath(String str) {
            this.certPath = str;
            return this;
        }

        /** Sets the timeout applied to the feeder. */
        public Builder timeout(Duration duration) {
            this.timeout = duration;
            return this;
        }

        /** Sets the namespace used in generated document ids and in {@code removeAll()}. */
        public Builder namespace(String str) {
            this.namespace = str;
            return this;
        }

        /** Sets the document type used in document ids, queries and {@code removeAll()}. */
        public Builder documentType(String str) {
            this.documentType = str;
            return this;
        }

        /** Sets the cluster name passed to the delete-all call. */
        public Builder clusterName(String str) {
            this.clusterName = str;
            return this;
        }

        /** Sets the rank profile used by {@code search}. */
        public Builder rankProfile(String str) {
            this.rankProfile = str;
            return this;
        }

        /** Sets the {@code targetHits} annotation of the nearest-neighbor query. */
        public Builder targetHits(Integer num) {
            this.targetHits = num;
            return this;
        }

        /**
         * When enabled (and no explicit id is given), document ids are derived from the segment
         * text instead of being random.
         */
        public Builder avoidDups(Boolean bool) {
            this.avoidDups = bool;
            return this;
        }

        /** Forwarded to the query API client factory. */
        public Builder logRequests(Boolean bool) {
            this.logRequests = bool;
            return this;
        }

        /** Forwarded to the query API client factory. */
        public Builder logResponses(Boolean bool) {
            this.logResponses = bool;
            return this;
        }

        /**
         * Creates the store from the configured values.
         *
         * @return a new {@link VespaEmbeddingStore}
         */
        public VespaEmbeddingStore build() {
            // The constructor expects clusterName BEFORE rankProfile; keep this order in sync.
            return new VespaEmbeddingStore(
                    this.url,
                    this.keyPath,
                    this.certPath,
                    this.timeout,
                    this.namespace,
                    this.documentType,
                    this.clusterName,
                    this.rankProfile,
                    this.targetHits,
                    this.avoidDups,
                    this.logRequests,
                    this.logResponses);
        }
    }

    /**
     * Creates a store. Any {@code null} argument other than the URL falls back to a default.
     *
     * @param str      Vespa endpoint URL; must not be {@code null}
     * @param str2     key file path, or {@code null}
     * @param str3     certificate file path, or {@code null}
     * @param duration feeder timeout; defaults to 5 seconds
     * @param str4     namespace; defaults to {@code "namespace"}
     * @param str5     document type; defaults to {@code "langchain4j"}
     * @param str6     cluster name; defaults to {@code "langchain4j"}
     * @param str7     rank profile; defaults to {@code "langchain4j_relevance_score"}
     * @param num      nearest-neighbor {@code targetHits}; defaults to 10
     * @param bool     whether to derive ids from segment text; defaults to {@code true}
     * @param bool2    whether to log requests; defaults to {@code false}
     * @param bool3    whether to log responses; defaults to {@code false}
     */
    public VespaEmbeddingStore(String str, String str2, String str3, Duration duration, String str4, String str5, String str6, String str7, Integer num, Boolean bool, Boolean bool2, Boolean bool3) {
        ValidationUtils.ensureNotNull(str, "url");
        this.url = str;
        this.keyPath = str2 != null ? Paths.get(str2) : null;
        this.certPath = str3 != null ? Paths.get(str3) : null;
        this.timeout = Utils.getOrDefault(duration, DEFAULT_TIMEOUT);
        this.namespace = Utils.getOrDefault(str4, DEFAULT_NAMESPACE);
        this.documentType = Utils.getOrDefault(str5, DEFAULT_DOCUMENT_TYPE);
        this.clusterName = Utils.getOrDefault(str6, DEFAULT_CLUSTER_NAME);
        this.rankProfile = Utils.getOrDefault(str7, DEFAULT_RANK_PROFILE);
        this.targetHits = Utils.getOrDefault(num, DEFAULT_TARGET_HITS);
        this.avoidDups = Utils.getOrDefault(bool, DEFAULT_AVOID_DUPS);
        this.logRequests = Utils.getOrDefault(bool2, false);
        this.logResponses = Utils.getOrDefault(bool3, false);
    }

    /** Converts one query hit into an {@link EmbeddingMatch}; the text segment may be absent. */
    private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Record hit) {
        Record.Fields fields = hit.fields();
        String embeddingId = DocumentId.of(fields.documentid()).userSpecific();
        Embedding embedding = Embedding.from(fields.vector().values());
        TextSegment segment = fields.textSegment() != null ? TextSegment.from(fields.textSegment()) : null;
        return new EmbeddingMatch<>(hit.relevance(), embeddingId, embedding, segment);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Stores an embedding under a generated id.
     *
     * @return the user-specific part of the stored document id, or {@code null} if the feed
     *         result was not of type {@code success}
     */
    public String add(Embedding embedding) {
        return add(null, embedding, null);
    }

    /** Stores an embedding under the given id. */
    public void add(String str, Embedding embedding) {
        add(str, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a generated id.
     *
     * @return the user-specific part of the stored document id, or {@code null} if the feed
     *         result was not of type {@code success}
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        return add(null, embedding, textSegment);
    }

    /** Stores the embeddings without text segments; delegates to the two-list overload. */
    public List<String> addAll(List<Embedding> list) {
        return addAll(list, null);
    }

    /**
     * Feeds a batch of embeddings in a single JSON array.
     *
     * @param list  ids, parallel to {@code list2}; may be {@code null}, in which case an id is
     *              generated for every embedding
     * @param list2 embeddings to store
     * @param list3 text segments, parallel to {@code list2}; may be {@code null}
     * @throws IllegalArgumentException if a non-null {@code list} or {@code list3} differs in
     *                                  size from {@code list2}
     * @throws RuntimeException         if the payload cannot be serialized, the feeder fails to
     *                                  close, or a feed failure was reported to the callback
     */
    public void addAll(List<String> list, List<Embedding> list2, List<TextSegment> list3) {
        if (list3 != null && list2.size() != list3.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        if (list != null && list.size() != list2.size()) {
            throw new IllegalArgumentException("The list of ids and embeddings must have the same size");
        }

        // First failure reported by the feeder; inspected once the feeder has been closed.
        final AtomicReference<FeedException> firstFailure = new AtomicReference<>();

        try (JsonFeeder feeder = feeder()) {
            List<Record> records = new ArrayList<>(list2.size());
            for (int i = 0; i < list2.size(); i++) {
                String id = list != null ? list.get(i) : null;
                TextSegment segment = list3 != null ? list3.get(i) : null;
                records.add(buildRecord(id, list2.get(i), segment));
            }

            // JSON payloads are UTF-8; do not depend on the platform default charset.
            byte[] payload = OBJECT_MAPPER.writeValueAsString(records).getBytes(StandardCharsets.UTF_8);

            feeder.feedMany(new ByteArrayInputStream(payload), new JsonFeeder.ResultCallback() {
                public void onNextResult(Result result, FeedException feedException) {
                    if (feedException != null) {
                        firstFailure.compareAndSet(null, feedException);
                        throw new RuntimeException(feedException.getMessage(), feedException);
                    }
                }

                public void onError(FeedException feedException) {
                    firstFailure.compareAndSet(null, feedException);
                    throw new RuntimeException(feedException.getMessage(), feedException);
                }
            });
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // NOTE(review): assumes JsonFeeder.close() returns only after pending callbacks have run;
        // confirm against the Vespa feed client documentation.
        FeedException failure = firstFailure.get();
        if (failure != null) {
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }

    /**
     * Runs a nearest-neighbor query on the {@code vector} field.
     *
     * <p>The query embedding is passed as {@code input.query(q)} and the minimum score as
     * {@code input.query(threshold)}; the configured rank profile is applied and at most
     * {@code maxResults} hits are requested.
     *
     * @throws RuntimeException wrapping any failure, including a non-successful HTTP response
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        try {
            Response<?> response = api()
                    .search(Q.select(FIELD_NAME_DOCUMENT_ID, FIELD_NAME_TEXT_SEGMENT, FIELD_NAME_VECTOR)
                            .from(this.documentType)
                            .where(buildNearestNeighbor())
                            .fix()
                            .hits(embeddingSearchRequest.maxResults())
                            .ranking(this.rankProfile)
                            .param("input.query(q)", OBJECT_MAPPER.writeValueAsString(embeddingSearchRequest.queryEmbedding().vectorAsList()))
                            .param("input.query(threshold)", String.valueOf(embeddingSearchRequest.minScore()))
                            .build())
                    .execute();
            if (!response.isSuccessful()) {
                throw toException(response);
            }

            List<Record> children = ((QueryResponse) response.body()).root().children();
            List<EmbeddingMatch<TextSegment>> matches;
            if (children == null || children.isEmpty()) {
                matches = new ArrayList<>();
            } else {
                matches = children.stream().map(VespaEmbeddingStore::toEmbeddingMatch).toList();
            }
            return new EmbeddingSearchResult<>(matches);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Issues the delete-all call for the configured namespace, document type and cluster.
     *
     * @throws RuntimeException wrapping any failure, including a non-successful HTTP response
     */
    public void removeAll() {
        try {
            Response<?> response = api().deleteAll(this.namespace, this.documentType, this.clusterName).execute();
            if (!response.isSuccessful()) {
                throw toException(response);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Feeds a single document and waits for the outcome.
     *
     * @param str         explicit id, or {@code null} to have one generated
     * @param embedding   vector to store
     * @param textSegment optional segment whose text is stored alongside the vector
     * @return the user-specific part of the document id when the result type is
     *         {@code success}, otherwise {@code null}
     * @throws RuntimeException wrapping serialization, feed or close failures
     */
    private String add(String str, Embedding embedding, TextSegment textSegment) {
        try (JsonFeeder feeder = feeder()) {
            String json = OBJECT_MAPPER.writeValueAsString(buildRecord(str, embedding, textSegment));
            // Wait for the outcome here: an exception thrown from a whenComplete callback only
            // fails the derived future, so a failed feed would otherwise go unnoticed.
            Result result = feeder.feedSingle(json).toCompletableFuture().join();
            if (Result.Type.success.equals(result.type())) {
                return result.documentId().userSpecific();
            }
            return null;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds a new feeder for the configured URL; the caller is responsible for closing it. */
    private JsonFeeder feeder() {
        FeedClientBuilder clientBuilder = FeedClientBuilder.create(URI.create(this.url));
        // The certificate is only configured when both files are given.
        if (this.certPath != null && this.keyPath != null) {
            clientBuilder.setCertificate(this.certPath, this.keyPath);
        }
        return JsonFeeder.builder(clientBuilder.build()).withTimeout(this.timeout).build();
    }

    /** Returns the query API client, creating it on first use. */
    private VespaApi api() {
        if (this.api == null) {
            this.api = VespaClient.createInstance(this.url, this.certPath, this.keyPath, this.logRequests, this.logResponses);
        }
        return this.api;
    }

    /**
     * Builds the feed record for one embedding.
     *
     * <p>Id selection: the explicit id if given; otherwise a UUID derived from the segment text
     * when duplicate avoidance is on and a segment is present; otherwise a random UUID.
     */
    private Record buildRecord(String str, Embedding embedding, TextSegment textSegment) {
        String userId;
        if (str != null) {
            userId = str;
        } else if (this.avoidDups && textSegment != null) {
            userId = Utils.generateUUIDFrom(textSegment.text());
        } else {
            userId = Utils.randomUUID();
        }

        String documentId = DocumentId.of(this.namespace, this.documentType, userId).toString();
        String text = textSegment != null ? textSegment.text() : null;
        Record.Fields fields = new Record.Fields(null, text, new Record.Fields.Vector(embedding.vectorAsList()));
        return new Record(documentId, null, fields);
    }

    /** Builds the nearest-neighbor clause on the vector field, annotated with {@code targetHits}. */
    private NearestNeighbor buildNearestNeighbor() {
        NearestNeighbor nearestNeighbor = Q.nearestNeighbor(FIELD_NAME_VECTOR, "q");
        nearestNeighbor.annotate(A.a("targetHits", this.targetHits));
        return nearestNeighbor;
    }

    /**
     * Turns a non-successful response into an exception carrying the status code and, when
     * present, the error body. The error body is closed before returning.
     *
     * @throws IOException if the error body cannot be read
     */
    private static RuntimeException toException(Response<?> response) throws IOException {
        // try-with-resources tolerates a null resource, so a missing body needs no special case.
        try (ResponseBody errorBody = response.errorBody()) {
            int code = response.code();
            if (errorBody != null) {
                return new RuntimeException(String.format("status code: %s; body: %s", code, errorBody.string()));
            }
            return new RuntimeException(String.format("status code: %s;", code));
        }
    }
}
