/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.mcp.nacos.client.transport;

import com.alibaba.cloud.ai.mcp.nacos.client.utils.NacosMcpClientUtils;
import com.alibaba.cloud.ai.mcp.nacos.service.NacosMcpOperationService;
import com.alibaba.cloud.ai.mcp.nacos.service.model.NacosMcpServerEndpoint;
import com.alibaba.nacos.api.ai.model.mcp.McpEndpointInfo;
import com.alibaba.nacos.api.exception.NacosException;
import com.alibaba.nacos.api.utils.StringUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.client.autoconfigure.NamedClientMcpTransport;
import org.springframework.ai.mcp.client.autoconfigure.configurer.McpAsyncClientConfigurer;
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

/**
 * Client-side load balancer over a set of {@link McpAsyncClient}s, one per backend endpoint of an
 * MCP server registered in Nacos.
 *
 * <p>The server's endpoints are resolved from Nacos in the constructor; {@link #init()} creates one
 * client per endpoint and {@link #subscribe()} keeps that client set in sync with later Nacos
 * updates. Single-target operations are routed to the client that has been handed out the fewest
 * times so far; broadcast operations (roots, resource subscriptions, logging level) are sent to
 * every client. Only servers whose protocol is {@code mcp-sse} are accepted.
 */
public class LoadbalancedMcpAsyncClient {
    private static final Logger logger = LoggerFactory.getLogger(LoadbalancedMcpAsyncClient.class);
    private static final String SSE_PROTOCOL = "mcp-sse";
    private final String serverName;
    private final NacosMcpOperationService nacosMcpOperationService;
    private final McpClientCommonProperties commonProperties;
    private final WebClient.Builder webClientBuilderTemplate;
    private final McpAsyncClientConfigurer mcpAsyncClientConfigurer;
    private final ObjectMapper objectMapper;
    // endpoint id -> client. Swapped wholesale by updateAll() from the Nacos subscription callback,
    // hence volatile; initialized eagerly so accessors never see null before init().
    private volatile Map<String, McpAsyncClient> keyToClientMap = new ConcurrentHashMap<>();
    // endpoint id -> number of times the client was handed out by getMcpAsyncClient().
    private volatile Map<String, Integer> keyToCountMap = new ConcurrentHashMap<>();
    private volatile NacosMcpServerEndpoint serverEndpoint;
    private final ApplicationContext applicationContext;

    /**
     * Resolves the server's endpoints from Nacos and looks up the Spring beans needed to build clients.
     *
     * @param serverName name of the MCP server registered in Nacos
     * @param nacosMcpOperationService service used to query and subscribe to the server in Nacos
     * @param applicationContext context providing {@link McpClientCommonProperties},
     *     {@link McpAsyncClientConfigurer}, {@link ObjectMapper} and {@link WebClient.Builder} beans
     * @throws IllegalArgumentException if any argument is null
     * @throws RuntimeException if the server cannot be found or its protocol is not {@code mcp-sse}
     */
    public LoadbalancedMcpAsyncClient(String serverName, NacosMcpOperationService nacosMcpOperationService, ApplicationContext applicationContext) {
        Assert.notNull(serverName, "serverName cannot be null");
        Assert.notNull(nacosMcpOperationService, "nacosMcpOperationService cannot be null");
        Assert.notNull(applicationContext, "applicationContext cannot be null");
        this.serverName = serverName;
        this.nacosMcpOperationService = nacosMcpOperationService;
        this.applicationContext = applicationContext;
        try {
            this.serverEndpoint = this.nacosMcpOperationService.getServerEndpoint(this.serverName);
            if (this.serverEndpoint == null) {
                throw new NacosException(404, String.format("Can not find mcp server from nacos: %s", serverName));
            }
            if (!StringUtils.equals(this.serverEndpoint.getProtocol(), SSE_PROTOCOL)) {
                throw new RuntimeException("mcp server protocol must be sse");
            }
        }
        catch (Exception e) {
            throw new RuntimeException(String.format("Failed to get instances for service: %s", serverName), e);
        }
        this.commonProperties = applicationContext.getBean(McpClientCommonProperties.class);
        this.mcpAsyncClientConfigurer = applicationContext.getBean(McpAsyncClientConfigurer.class);
        this.objectMapper = applicationContext.getBean(ObjectMapper.class);
        this.webClientBuilderTemplate = applicationContext.getBean(WebClient.Builder.class);
    }

    /**
     * Creates one client per endpoint currently known for the server, replacing the client maps.
     */
    public void init() {
        this.keyToClientMap = new ConcurrentHashMap<>();
        this.keyToCountMap = new ConcurrentHashMap<>();
        List<McpEndpointInfo> endpointInfos = this.serverEndpoint.getMcpEndpointInfoList();
        if (endpointInfos == null) {
            return;
        }
        for (McpEndpointInfo mcpEndpointInfo : endpointInfos) {
            this.updateByAddEndpoint(mcpEndpointInfo, this.serverEndpoint.getExportPath());
        }
    }

    /**
     * Subscribes to Nacos changes for this server and reconciles the client set on every update.
     * Updates whose protocol is not {@code mcp-sse} are ignored.
     */
    public void subscribe() {
        this.nacosMcpOperationService.subscribeNacosMcpServer(this.serverName, mcpServerDetailInfo -> {
            String protocol = mcpServerDetailInfo.getProtocol();
            if (!StringUtils.equals(protocol, SSE_PROTOCOL)) {
                return;
            }
            List<McpEndpointInfo> mcpEndpointInfoList = mcpServerDetailInfo.getBackendEndpoints() == null
                    ? new ArrayList<>()
                    : mcpServerDetailInfo.getBackendEndpoints();
            String exportPath = mcpServerDetailInfo.getRemoteServerConfig().getExportPath();
            String realVersion = mcpServerDetailInfo.getVersionDetail().getVersion();
            this.updateClientList(new NacosMcpServerEndpoint(mcpEndpointInfoList, exportPath, protocol, realVersion));
        });
    }

    /**
     * Returns the client that has been handed out the fewest times and increments its counter.
     *
     * @return the selected client
     * @throws IllegalStateException if no client is available
     */
    public McpAsyncClient getMcpAsyncClient() {
        // Snapshot both maps so a concurrent updateAll() swap cannot mix old and new state.
        Map<String, McpAsyncClient> clients = this.keyToClientMap;
        Map<String, Integer> counts = this.keyToCountMap;
        String selectedKey = null;
        McpAsyncClient selectedClient = null;
        int minCount = Integer.MAX_VALUE;
        for (Map.Entry<String, McpAsyncClient> entry : clients.entrySet()) {
            int count = counts.getOrDefault(entry.getKey(), 0);
            if (selectedClient == null || count < minCount) {
                selectedKey = entry.getKey();
                selectedClient = entry.getValue();
                minCount = count;
            }
        }
        if (selectedClient == null) {
            throw new IllegalStateException("No McpAsyncClient available");
        }
        // Atomic increment; a plain get/put pair would lose updates under concurrency.
        counts.merge(selectedKey, 1, Integer::sum);
        return selectedClient;
    }

    /** Returns an unmodifiable snapshot of all current clients. */
    public List<McpAsyncClient> getMcpAsyncClientList() {
        return this.keyToClientMap.values().stream().toList();
    }

    public String getServerName() {
        return this.serverName;
    }

    /** Returns the most recently applied endpoint description for the server. */
    public NacosMcpServerEndpoint getNacosMcpServerEndpoint() {
        return this.serverEndpoint;
    }

    // The following accessors go through getMcpAsyncClient(), so each call also counts as a use.

    public McpSchema.ServerCapabilities getServerCapabilities() {
        return this.getMcpAsyncClient().getServerCapabilities();
    }

    public McpSchema.Implementation getServerInfo() {
        return this.getMcpAsyncClient().getServerInfo();
    }

    public boolean isInitialized() {
        return this.getMcpAsyncClient().isInitialized();
    }

    public McpSchema.ClientCapabilities getClientCapabilities() {
        return this.getMcpAsyncClient().getClientCapabilities();
    }

    public McpSchema.Implementation getClientInfo() {
        return this.getMcpAsyncClient().getClientInfo();
    }

    /** Closes every client immediately and removes it (and its counter) from this balancer. */
    public void close() {
        Iterator<Map.Entry<String, McpAsyncClient>> iterator = this.keyToClientMap.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<String, McpAsyncClient> entry = iterator.next();
            McpAsyncClient mcpAsyncClient = entry.getValue();
            mcpAsyncClient.close();
            // ConcurrentHashMap's entry-set iterator supports remove(); the list snapshot does not.
            iterator.remove();
            this.keyToCountMap.remove(entry.getKey());
            logger.info("Closed and removed McpAsyncClient: {}", mcpAsyncClient.getClientInfo().name());
        }
    }

    /**
     * Gracefully closes every client; each one is removed from this balancer once its close succeeds.
     *
     * @return a {@link Mono} completing when all clients have been closed
     */
    public Mono<Void> closeGracefully() {
        Map<String, McpAsyncClient> clients = this.keyToClientMap;
        Map<String, Integer> counts = this.keyToCountMap;
        List<Mono<Void>> closeMonos = new ArrayList<>();
        for (Map.Entry<String, McpAsyncClient> entry : clients.entrySet()) {
            String key = entry.getKey();
            McpAsyncClient mcpAsyncClient = entry.getValue();
            closeMonos.add(mcpAsyncClient.closeGracefully().doOnSuccess(v -> {
                // Remove by key/value so a client that replaced this one in the meantime is kept.
                clients.remove(key, mcpAsyncClient);
                counts.remove(key);
                logger.info("Closed and removed McpAsyncClient: {}", mcpAsyncClient.getClientInfo().name());
            }));
        }
        return Mono.when(closeMonos);
    }

    public Mono<Object> ping() {
        return this.getMcpAsyncClient().ping();
    }

    /** Adds {@code root} on every client. */
    public Mono<Void> addRoot(McpSchema.Root root) {
        return this.broadcast(mcpAsyncClient -> mcpAsyncClient.addRoot(root));
    }

    /** Removes the root identified by {@code rootUri} on every client. */
    public Mono<Void> removeRoot(String rootUri) {
        return this.broadcast(mcpAsyncClient -> mcpAsyncClient.removeRoot(rootUri));
    }

    /** Sends the roots-list-changed notification through every client. */
    public Mono<Void> rootsListChangedNotification() {
        return this.broadcast(McpAsyncClient::rootsListChangedNotification);
    }

    public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
        return this.getMcpAsyncClient().callTool(callToolRequest);
    }

    public Mono<McpSchema.ListToolsResult> listTools() {
        return this.listToolsInternal(null);
    }

    public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
        return this.listToolsInternal(cursor);
    }

    private Mono<McpSchema.ListToolsResult> listToolsInternal(String cursor) {
        return this.getMcpAsyncClient().listTools(cursor);
    }

    public Mono<McpSchema.ListResourcesResult> listResources() {
        return this.getMcpAsyncClient().listResources();
    }

    public Mono<McpSchema.ListResourcesResult> listResources(String cursor) {
        return this.getMcpAsyncClient().listResources(cursor);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resource) {
        return this.getMcpAsyncClient().readResource(resource);
    }

    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
        return this.getMcpAsyncClient().readResource(readResourceRequest);
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
        return this.getMcpAsyncClient().listResourceTemplates();
    }

    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String cursor) {
        return this.getMcpAsyncClient().listResourceTemplates(cursor);
    }

    /** Subscribes to a resource on every client. */
    public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
        return this.broadcast(mcpAsyncClient -> mcpAsyncClient.subscribeResource(subscribeRequest));
    }

    /** Unsubscribes from a resource on every client. */
    public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
        return this.broadcast(mcpAsyncClient -> mcpAsyncClient.unsubscribeResource(unsubscribeRequest));
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts() {
        return this.getMcpAsyncClient().listPrompts();
    }

    public Mono<McpSchema.ListPromptsResult> listPrompts(String cursor) {
        return this.getMcpAsyncClient().listPrompts(cursor);
    }

    public Mono<McpSchema.GetPromptResult> getPrompt(McpSchema.GetPromptRequest getPromptRequest) {
        return this.getMcpAsyncClient().getPrompt(getPromptRequest);
    }

    /** Sets the logging level on every client. */
    public Mono<Void> setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        return this.broadcast(mcpAsyncClient -> mcpAsyncClient.setLoggingLevel(loggingLevel));
    }

    /** Applies {@code operation} to every current client and completes when all results complete. */
    private Mono<Void> broadcast(Function<McpAsyncClient, Mono<Void>> operation) {
        List<Mono<Void>> results = this.getMcpAsyncClientList().stream().map(operation).toList();
        return Mono.when(results);
    }

    /**
     * Reconciles the client set with a new endpoint description. A changed export path or version
     * rebuilds every client; otherwise only added/removed endpoints are touched.
     */
    private void updateClientList(NacosMcpServerEndpoint newServerEndpoint) {
        if (!StringUtils.equals(this.serverEndpoint.getExportPath(), newServerEndpoint.getExportPath())
                || !StringUtils.equals(this.serverEndpoint.getVersion(), newServerEndpoint.getVersion())) {
            this.updateAll(newServerEndpoint);
        } else {
            List<McpEndpointInfo> currentMcpEndpointInfoList = this.serverEndpoint.getMcpEndpointInfoList();
            List<McpEndpointInfo> newMcpEndpointInfoList = newServerEndpoint.getMcpEndpointInfoList();
            List<McpEndpointInfo> addEndpointInfoList = newMcpEndpointInfoList.stream()
                    .filter(newEndpoint -> currentMcpEndpointInfoList.stream().noneMatch(currentEndpoint -> sameEndpoint(currentEndpoint, newEndpoint)))
                    .toList();
            List<McpEndpointInfo> removeEndpointInfoList = currentMcpEndpointInfoList.stream()
                    .filter(currentEndpoint -> newMcpEndpointInfoList.stream().noneMatch(newEndpoint -> sameEndpoint(newEndpoint, currentEndpoint)))
                    .toList();
            for (McpEndpointInfo addEndpointInfo : addEndpointInfoList) {
                this.updateByAddEndpoint(addEndpointInfo, newServerEndpoint.getExportPath());
            }
            for (McpEndpointInfo removeEndpointInfo : removeEndpointInfoList) {
                this.updateByRemoveEndpoint(removeEndpointInfo, newServerEndpoint.getExportPath());
            }
        }
        this.serverEndpoint = newServerEndpoint;
    }

    /** True when both endpoints have the same address and port (null-safe, works for boxed ports). */
    private static boolean sameEndpoint(McpEndpointInfo a, McpEndpointInfo b) {
        return Objects.equals(a.getAddress(), b.getAddress()) && Objects.equals(a.getPort(), b.getPort());
    }

    /**
     * Builds a fresh client for every endpoint, swaps the maps in, then closes all previous clients.
     * If building fails part-way, the clients created so far are closed and the error is rethrown.
     */
    private void updateAll(NacosMcpServerEndpoint newServerEndpoint) {
        String exportPath = newServerEndpoint.getExportPath();
        Map<String, McpAsyncClient> newKeyToClientMap = new ConcurrentHashMap<>();
        Map<String, Integer> newKeyToCountMap = new ConcurrentHashMap<>();
        try {
            for (McpEndpointInfo mcpEndpointInfo : newServerEndpoint.getMcpEndpointInfoList()) {
                String key = NacosMcpClientUtils.getMcpEndpointInfoId(mcpEndpointInfo, exportPath);
                // Check before connecting so duplicate endpoints do not leak a client.
                if (newKeyToClientMap.containsKey(key)) {
                    continue;
                }
                newKeyToClientMap.put(key, this.clientByEndpoint(mcpEndpointInfo, exportPath));
                newKeyToCountMap.put(key, 0);
            }
        }
        catch (RuntimeException e) {
            newKeyToClientMap.values().forEach(this::closeClientGracefully);
            throw e;
        }
        Map<String, McpAsyncClient> oldKeyToClientMap = this.keyToClientMap;
        this.keyToClientMap = newKeyToClientMap;
        this.keyToCountMap = newKeyToCountMap;
        for (McpAsyncClient asyncClient : oldKeyToClientMap.values()) {
            this.closeClientGracefully(asyncClient);
        }
    }

    /**
     * Builds (and, if configured, initializes) a client connected over SSE to the given endpoint.
     * The client is closed if initialization fails.
     */
    private McpAsyncClient clientByEndpoint(McpEndpointInfo mcpEndpointInfo, String exportPath) {
        String baseUrl = "http://" + mcpEndpointInfo.getAddress() + ":" + mcpEndpointInfo.getPort();
        WebClient.Builder webClientBuilder = this.webClientBuilderTemplate.clone().baseUrl(baseUrl);
        WebFluxSseClientTransport transport = new WebFluxSseClientTransport(webClientBuilder, this.objectMapper, exportPath);
        NamedClientMcpTransport namedTransport = new NamedClientMcpTransport(this.serverName + "-" + NacosMcpClientUtils.getMcpEndpointInfoId(mcpEndpointInfo, exportPath), transport);
        McpSchema.Implementation clientInfo = new McpSchema.Implementation(this.connectedClientName(this.commonProperties.getName(), namedTransport.name()), this.commonProperties.getVersion());
        McpClient.AsyncSpec asyncSpec = McpClient.async(namedTransport.transport()).clientInfo(clientInfo).requestTimeout(this.commonProperties.getRequestTimeout());
        asyncSpec = this.mcpAsyncClientConfigurer.configure(namedTransport.name(), asyncSpec);
        McpAsyncClient asyncClient = asyncSpec.build();
        if (this.commonProperties.isInitialized()) {
            try {
                asyncClient.initialize().block();
            }
            catch (RuntimeException e) {
                asyncClient.close();
                throw e;
            }
        }
        logger.info("Added McpAsyncClient: {}", clientInfo.name());
        return asyncClient;
    }

    /** Creates and registers a client for {@code endpointInfo} unless one is already registered. */
    private void updateByAddEndpoint(McpEndpointInfo endpointInfo, String exportPath) {
        String key = NacosMcpClientUtils.getMcpEndpointInfoId(endpointInfo, exportPath);
        if (this.keyToClientMap.containsKey(key)) {
            return;
        }
        McpAsyncClient mcpAsyncClient = this.clientByEndpoint(endpointInfo, exportPath);
        McpAsyncClient existing = this.keyToClientMap.putIfAbsent(key, mcpAsyncClient);
        if (existing != null) {
            // Lost a race with another registration for the same key; discard our client.
            mcpAsyncClient.close();
            return;
        }
        this.keyToCountMap.putIfAbsent(key, 0);
    }

    /** Unregisters and gracefully closes the client for {@code endpointInfo}, if present. */
    private void updateByRemoveEndpoint(McpEndpointInfo endpointInfo, String exportPath) {
        String key = NacosMcpClientUtils.getMcpEndpointInfoId(endpointInfo, exportPath);
        McpAsyncClient asyncClient = this.keyToClientMap.remove(key);
        if (asyncClient == null) {
            return;
        }
        this.keyToCountMap.remove(key);
        this.closeClientGracefully(asyncClient);
    }

    /**
     * Gracefully closes a client, blocking until done. Failures are logged rather than propagated so
     * that one misbehaving client does not prevent the others from being closed.
     */
    private void closeClientGracefully(McpAsyncClient asyncClient) {
        String clientName = asyncClient.getClientInfo().name();
        logger.info("Removing McpAsyncClient: {}", clientName);
        try {
            asyncClient.closeGracefully().block();
            logger.info("Removed McpAsyncClient: {} Success", clientName);
        }
        catch (RuntimeException e) {
            logger.warn("Failed to close McpAsyncClient: {}", clientName, e);
        }
    }

    private String connectedClientName(String clientName, String serverConnectionName) {
        return clientName + " - " + serverConnectionName;
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link LoadbalancedMcpAsyncClient}. */
    public static class Builder {
        private String serverName;
        private NacosMcpOperationService nacosMcpOperationService;
        private ApplicationContext applicationContext;

        public Builder serverName(String serverName) {
            this.serverName = serverName;
            return this;
        }

        public Builder nacosMcpOperationService(NacosMcpOperationService nacosMcpOperationService) {
            this.nacosMcpOperationService = nacosMcpOperationService;
            return this;
        }

        public Builder applicationContext(ApplicationContext applicationContext) {
            this.applicationContext = applicationContext;
            return this;
        }

        public LoadbalancedMcpAsyncClient build() {
            return new LoadbalancedMcpAsyncClient(this.serverName, this.nacosMcpOperationService, this.applicationContext);
        }
    }
}

