package dev.langchain4j.model.bedrock;

import dev.langchain4j.Internal;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.bedrock.BedrockEmbeddingResponse;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

/**
 * Base class for Amazon Bedrock embedding models that are called through
 * {@link BedrockRuntimeClient#invokeModel(InvokeModelRequest)}.
 *
 * <p>Subclasses provide three things:
 * <ul>
 *   <li>the model id ({@link #getModelId()});</li>
 *   <li>the translation of text segments into request payloads ({@link #getRequestParameters(List)});</li>
 *   <li>the class each JSON reply is deserialized into ({@link #getResponseClassType()}).</li>
 * </ul>
 *
 * <p>When a setting is not configured on the builder, these defaults apply:
 * <ul>
 *   <li>region {@link Region#US_EAST_1};</li>
 *   <li>a {@link DefaultCredentialsProvider};</li>
 *   <li>{@code 2} as the retry limit.</li>
 * </ul>
 * If no client is passed to the builder, one is created lazily on first use from the configured
 * region and credentials provider.
 *
 * @param <T> the model-specific response type
 */
@Internal
public abstract class AbstractBedrockEmbeddingModel<T extends BedrockEmbeddingResponse> extends DimensionAwareEmbeddingModel {
    private static final Region DEFAULT_REGION = Region.US_EAST_1;
    private static final AwsCredentialsProvider DEFAULT_CREDENTIALS_PROVIDER = DefaultCredentialsProvider.builder().build();
    private static final Integer DEFAULT_MAX_RETRIES = 2;

    /**
     * Either supplied by the builder or created lazily by {@link #getClient()}.
     * It is volatile so the lazily created instance is safely published to other threads.
     */
    private volatile BedrockRuntimeClient client;
    private final Region region;
    private final AwsCredentialsProvider credentialsProvider;
    private final Integer maxRetries;

    /**
     * Self-typed builder shared by all Bedrock embedding model builders.
     *
     * <p>Each {@code is...Set} flag records whether the matching setter was called. This lets the
     * model constructor tell "not configured" (use the default) apart from an explicitly
     * configured value.
     *
     * @param <T> the model-specific response type
     * @param <C> the concrete model type produced by {@link #build()}
     * @param <B> the concrete builder type returned by the fluent setters
     */
    public static abstract class AbstractBedrockEmbeddingModelBuilder<T extends BedrockEmbeddingResponse, C extends AbstractBedrockEmbeddingModel<T>, B extends AbstractBedrockEmbeddingModelBuilder<T, C, B>> {
        private BedrockRuntimeClient client;
        private Region region;
        private boolean isRegionSet;
        private AwsCredentialsProvider credentialsProvider;
        private boolean isCredentialsProviderSet;
        private Integer maxRetries;
        private boolean isMaxRetriesSet;

        /**
         * Uses the given client instead of creating one lazily.
         *
         * @param bedrockRuntimeClient client used for every model invocation
         * @return this builder
         */
        public B client(BedrockRuntimeClient bedrockRuntimeClient) {
            this.client = bedrockRuntimeClient;
            return self();
        }

        /**
         * Sets the region used when the client is created lazily.
         *
         * @param region the AWS region
         * @return this builder
         */
        public B region(Region region) {
            this.region = region;
            this.isRegionSet = true;
            return self();
        }

        /**
         * Sets the credentials provider used when the client is created lazily.
         *
         * @param awsCredentialsProvider the credentials provider
         * @return this builder
         */
        public B credentialsProvider(AwsCredentialsProvider awsCredentialsProvider) {
            this.credentialsProvider = awsCredentialsProvider;
            this.isCredentialsProviderSet = true;
            return self();
        }

        /**
         * Sets the retry limit passed to {@link RetryUtils} for each invocation.
         * A {@code null} value falls back to the default.
         *
         * @param num the retry limit
         * @return this builder
         */
        public B maxRetries(Integer num) {
            this.maxRetries = num;
            this.isMaxRetriesSet = true;
            return self();
        }

        /** Returns {@code this} typed as the concrete builder. */
        protected abstract B self();

        /** Creates the configured model. */
        public abstract C build();
    }

    /**
     * Copies the builder's settings into this model.
     * Any setting the builder did not configure gets its class default.
     *
     * @param builder the configured builder
     */
    public AbstractBedrockEmbeddingModel(AbstractBedrockEmbeddingModelBuilder<T, ?, ?> builder) {
        this.client = builder.client;
        this.region = builder.isRegionSet ? builder.region : DEFAULT_REGION;
        this.credentialsProvider = builder.isCredentialsProviderSet
                ? builder.credentialsProvider
                : DEFAULT_CREDENTIALS_PROVIDER;
        // An explicit null would otherwise only fail later, as an NPE when unboxing inside embedAll.
        this.maxRetries = builder.isMaxRetriesSet && builder.maxRetries != null
                ? builder.maxRetries
                : DEFAULT_MAX_RETRIES;
    }

    /**
     * Embeds the given segments.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Build one request payload per map returned by {@link #getRequestParameters(List)}.</li>
     *   <li>Invoke the model once per payload, retrying as configured.</li>
     *   <li>Convert each reply into an {@link Embedding}.</li>
     * </ol>
     *
     * @param list the segments to embed
     * @return the embeddings in reply order, with token usage equal to the sum of each reply's
     *         input text token count
     */
    public Response<List<Embedding>> embedAll(List<TextSegment> list) {
        List<T> responses = getRequestParameters(list).stream()
                .map(Json::toJson)
                .map(this::invokeWithRetries)
                .map(invokeModelResponse -> invokeModelResponse.body().asUtf8String())
                .map(json -> Json.fromJson(json, getResponseClassType()))
                .collect(Collectors.toList());

        List<Embedding> embeddings = new ArrayList<>(responses.size());
        int inputTokenCount = 0;
        for (T response : responses) {
            embeddings.add(response.toEmbedding());
            inputTokenCount += response.getInputTextTokenCount();
        }
        return Response.from(embeddings, new TokenUsage(inputTokenCount));
    }

    /** Invokes the model with one JSON body, retrying via {@link RetryUtils} and mapping exceptions with {@link BedrockExceptionMapper}. */
    private InvokeModelResponse invokeWithRetries(String body) {
        return RetryUtils.withRetryMappingExceptions(() -> invoke(body), maxRetries, BedrockExceptionMapper.INSTANCE);
    }

    /**
     * Builds the request payloads for the given segments.
     * Each returned map is serialized to JSON and sent as one invocation.
     *
     * @param list the segments to embed
     * @return one parameter map per model invocation
     */
    protected abstract List<Map<String, Object>> getRequestParameters(List<TextSegment> list);

    /**
     * Returns the client, creating it from the configured region and credentials provider on first use
     * if none was supplied. Uses double-checked locking on {@code this}, with a local copy so the
     * fast path reads the volatile field only once.
     *
     * @return the Bedrock runtime client
     */
    public BedrockRuntimeClient getClient() {
        BedrockRuntimeClient result = this.client;
        if (result == null) {
            synchronized (this) {
                result = this.client;
                if (result == null) {
                    result = initClient();
                    this.client = result;
                }
            }
        }
        return result;
    }

    /** Returns the Bedrock model id to invoke. */
    protected abstract String getModelId();

    /** Returns the class each JSON reply is deserialized into. */
    protected abstract Class<T> getResponseClassType();

    /**
     * Sends one JSON request body to the model.
     *
     * <p>The body is always encoded as UTF-8, the encoding used to decode the reply and expected for JSON.
     * The platform default charset is not used because, before JDK 18, it may differ and mangle non-ASCII text.
     *
     * @param str the JSON request body
     * @return the raw model response
     */
    protected InvokeModelResponse invoke(String str) {
        InvokeModelRequest request = InvokeModelRequest.builder()
                .modelId(getModelId())
                .body(SdkBytes.fromUtf8String(str))
                .build();
        return getClient().invokeModel(request);
    }

    /**
     * Creates a mutable map holding a single entry.
     *
     * @param str the key
     * @param obj the value
     * @return a new {@link HashMap} containing only {@code str -> obj}
     */
    public static Map<String, Object> of(String str, Object obj) {
        Map<String, Object> map = new HashMap<>(2);
        map.put(str, obj);
        return map;
    }

    /** Creates a client from the configured region and credentials provider. */
    private BedrockRuntimeClient initClient() {
        return BedrockRuntimeClient.builder()
                .region(this.region)
                .credentialsProvider(this.credentialsProvider)
                .build();
    }

    /** Returns the configured region, or the default if none was set. */
    public Region getRegion() {
        return this.region;
    }

    /** Returns the configured credentials provider, or the default if none was set. */
    public AwsCredentialsProvider getCredentialsProvider() {
        return this.credentialsProvider;
    }

    /** Returns the retry limit, or the default if none (or {@code null}) was set. */
    public Integer getMaxRetries() {
        return this.maxRetries;
    }
}
