use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use tokio::sync::{mpsc, RwLock};
use atomcode_telemetry::{Event as TelemetryEvent, McpErrorKind, McpTransport};
use super::client::{McpClient, McpToolInfo};
use super::config::{load_mcp_config, McpServerConfig};
use super::transport_http::HttpClient;
use super::transport_stdio::StdioClient;
use super::types::ServerStatus;
/// Progress notification produced while MCP servers are connected or queried.
///
/// Delivered over the optional unbounded channel held by `McpRegistry`; every send
/// is best-effort (the registry discards send errors).
#[derive(Debug, Clone)]
pub enum McpConnectEvent {
/// A server completed `initialize` and was inserted into the registry.
Connected { name: String },
/// Loading the config or connecting a server failed. `name` is `"config"` when the
/// config itself could not be loaded.
Failed { name: String, error: String },
/// Non-fatal problem while listing a server's tools (e.g. `tools/list` failed, or the
/// requested server is not registered).
Warning { name: String, message: String },
}
/// Registry of connected MCP servers, keyed by server name.
///
/// The server map, timeout map and readiness signal sit behind `Arc`s, so handles
/// created by `share()` and background connection tasks observe the same state.
pub struct McpRegistry {
/// Connected clients by server name (a `BTreeMap`, so iteration is name-ordered).
servers: Arc<RwLock<BTreeMap<String, Arc<dyn McpClient>>>>,
/// Per-server request timeout in milliseconds, recorded when a server connects.
server_timeouts_ms: Arc<RwLock<BTreeMap<String, u64>>>,
/// Optional sink for connection / tool-listing progress events.
connect_events: Option<mpsc::UnboundedSender<McpConnectEvent>>,
/// Signalled when the initial background connection pass finishes.
initial_ready: Arc<tokio::sync::Notify>,
/// Optional telemetry sink for connection outcomes.
telemetry: Option<Arc<atomcode_telemetry::Telemetry>>,
}
impl McpRegistry {
pub fn new() -> Self {
Self {
servers: Arc::new(RwLock::new(BTreeMap::new())),
server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
connect_events: None,
initial_ready: Arc::new(tokio::sync::Notify::new()),
telemetry: None,
}
}
pub fn with_telemetry(mut self, tel: Arc<atomcode_telemetry::Telemetry>) -> Self {
self.telemetry = Some(tel);
self
}
pub fn with_event_channel() -> (Self, mpsc::UnboundedReceiver<McpConnectEvent>) {
let (tx, rx) = mpsc::unbounded_channel();
(
Self {
servers: Arc::new(RwLock::new(BTreeMap::new())),
server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
connect_events: Some(tx),
initial_ready: Arc::new(tokio::sync::Notify::new()),
telemetry: None,
},
rx,
)
}
pub fn event_sender(&self) -> Option<mpsc::UnboundedSender<McpConnectEvent>> {
self.connect_events.clone()
}
pub fn from_config_background(project_dir: &std::path::Path) -> Self {
Self::from_config_background_with_events(project_dir, None)
}
pub fn from_config_background_with_events(
project_dir: &std::path::Path,
event_tx: Option<mpsc::UnboundedSender<McpConnectEvent>>,
) -> Self {
let mut registry = Self::new();
let combined_tx = event_tx.or(registry.connect_events.clone());
registry.connect_events = combined_tx.clone();
let configs = match load_mcp_config(project_dir) {
Ok(c) => c,
Err(e) => {
if let Some(tx) = &combined_tx {
let _ = tx.send(McpConnectEvent::Failed {
name: "config".to_string(),
error: format!("Failed to load config: {}", e),
});
}
return registry;
}
};
if !configs.is_empty() {
let servers = registry.servers.clone();
let server_timeouts_ms = registry.server_timeouts_ms.clone();
let initial_ready = registry.initial_ready.clone();
let telemetry = registry.telemetry.clone();
tokio::spawn(async move {
let tasks: Vec<_> = configs
.into_iter()
.map(|config| {
let servers = servers.clone();
let server_timeouts_ms = server_timeouts_ms.clone();
let tx = combined_tx.clone();
let telemetry = telemetry.clone();
async move {
let name = config.name.clone();
let timeout_ms = config.timeout_ms();
let config_source = config.source;
let transport = match &config.config {
super::config::McpTransportConfig::Stdio { .. } => McpTransport::Stdio,
super::config::McpTransportConfig::Http { .. } => McpTransport::StreamableHttp,
};
let start = std::time::Instant::now();
let mut client: Box<dyn McpClient> = match &config.config {
super::config::McpTransportConfig::Stdio {
command,
args,
env,
timeout_ms,
} => Box::new(StdioClient::new(
name.clone(),
command.clone(),
args.clone(),
env.clone(),
*timeout_ms,
)),
super::config::McpTransportConfig::Http {
url,
headers,
auth,
timeout_ms,
} => Box::new(HttpClient::new(
name.clone(),
url.clone(),
headers.clone(),
auth.clone(),
*timeout_ms,
)),
};
match client.initialize().await {
Ok(_result) => {
let duration_ms = start.elapsed().as_millis() as u32;
let mut servers = servers.write().await;
servers.insert(name.clone(), Arc::from(client));
drop(servers);
let mut timeouts = server_timeouts_ms.write().await;
timeouts.insert(name.clone(), timeout_ms);
if let Some(tx) = tx {
let _ = tx.send(McpConnectEvent::Connected {
name: name.clone(),
});
}
if let Some(tel) = &telemetry {
tel.track(TelemetryEvent::McpConnect {
server_name: name.clone(),
transport,
success: true,
duration_ms: Some(duration_ms),
error_kind: None,
error_data: Some(serde_json::json!({
"server_name": name,
"transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
"duration_ms": duration_ms,
"tool_count": 0,
"config_source": config_source.as_str(),
}).to_string()),
});
}
}
Err(e) => {
let duration_ms = start.elapsed().as_millis() as u32;
let error_str = format!("{}", e);
if let Some(tx) = tx {
let _ = tx.send(McpConnectEvent::Failed {
name: name.clone(),
error: error_str.clone(),
});
}
if let Some(tel) = &telemetry {
let error_kind = classify_mcp_error(&error_str);
tel.track(TelemetryEvent::McpConnect {
server_name: name.clone(),
transport,
success: false,
duration_ms: Some(duration_ms),
error_kind: Some(error_kind),
error_data: Some(serde_json::json!({
"server_name": name,
"transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
"duration_ms": duration_ms,
"message": atomcode_telemetry::scrub::truncate_head(&error_str, 200),
"config_source": config_source.as_str(),
}).to_string()),
});
}
}
}
}
})
.collect();
futures::future::join_all(tasks).await;
initial_ready.notify_waiters();
});
} else {
registry.initial_ready.notify_waiters();
}
registry
}
pub async fn from_config(project_dir: &std::path::Path) -> Self {
let registry = Self::new();
let configs = match load_mcp_config(project_dir) {
Ok(c) => c,
Err(e) => {
tracing::warn!("[mcp] Failed to load config: {}", e);
return registry;
}
};
for config in configs {
if let Err(e) = registry.add_server(config).await {
tracing::warn!("[mcp] Failed to connect server: {}", e);
}
}
registry
}
pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
let mut client: Box<dyn McpClient> = match &config.config {
super::config::McpTransportConfig::Stdio {
command,
args,
env,
timeout_ms,
} => Box::new(StdioClient::new(
config.name.clone(),
command.clone(),
args.clone(),
env.clone(),
*timeout_ms,
)),
super::config::McpTransportConfig::Http {
url,
headers,
auth,
timeout_ms,
} => Box::new(HttpClient::new(
config.name.clone(),
url.clone(),
headers.clone(),
auth.clone(),
*timeout_ms,
)),
};
client.initialize().await?;
let mut servers = self.servers.write().await;
servers.insert(config.name.clone(), Arc::from(client));
drop(servers);
let mut timeouts = self.server_timeouts_ms.write().await;
timeouts.insert(config.name.clone(), config.timeout_ms());
Ok(())
}
pub async fn list_tools_timeout(&self, server_name: &str) -> Duration {
let configured_ms = {
let timeouts = self.server_timeouts_ms.read().await;
timeouts.get(server_name).copied().unwrap_or(30_000)
};
Duration::from_millis(configured_ms.saturating_add(5_000))
}
pub async fn list_all_tools(&self) -> Vec<McpToolInfo> {
let server_snapshot: Vec<(String, Arc<dyn McpClient>)> = {
let servers = self.servers.read().await;
servers
.iter()
.map(|(name, client)| (name.clone(), Arc::clone(client)))
.collect()
};
let mut all_tools = Vec::new();
for (server_name, client) in server_snapshot {
match client.list_tools().await {
Ok(result) => {
for tool in result.tools {
all_tools.push(McpToolInfo {
server_name: server_name.clone(),
tool_name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
});
}
}
Err(e) => {
if let Some(tx) = &self.connect_events {
let _ = tx.send(McpConnectEvent::Warning {
name: server_name.clone(),
message: format!("tools/list failed: {}", e),
});
} else {
tracing::warn!("[mcp] Failed to list tools from {}: {}", server_name, e);
}
}
}
}
all_tools
}
pub async fn list_tools_for_server(&self, server_name: &str) -> Vec<McpToolInfo> {
let client = {
let servers = self.servers.read().await;
servers.get(server_name).map(Arc::clone)
};
let Some(client) = client else {
if let Some(tx) = &self.connect_events {
let _ = tx.send(McpConnectEvent::Warning {
name: server_name.to_string(),
message: "tools/list skipped: server not found".to_string(),
});
}
return Vec::new();
};
match client.list_tools().await {
Ok(result) => result
.tools
.into_iter()
.map(|tool| McpToolInfo {
server_name: server_name.to_string(),
tool_name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
})
.collect(),
Err(e) => {
if let Some(tx) = &self.connect_events {
let _ = tx.send(McpConnectEvent::Warning {
name: server_name.to_string(),
message: format!("tools/list failed: {}", e),
});
} else {
tracing::warn!("[mcp] Failed to list tools from {}: {}", server_name, e);
}
Vec::new()
}
}
}
pub async fn call_tool(
&self,
server_name: &str,
tool_name: &str,
arguments: serde_json::Value,
) -> Result<String> {
let servers = self.servers.read().await;
let client = servers
.get(server_name)
.ok_or_else(|| anyhow::anyhow!("MCP server '{}' not found", server_name))?;
let result = client.call_tool(tool_name, arguments).await?;
let output = result
.content
.into_iter()
.filter_map(|c| match c {
super::types::ContentBlock::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");
if result.is_error {
anyhow::bail!("MCP tool error: {}", output);
}
Ok(output)
}
pub async fn server_statuses(&self) -> Vec<(String, ServerStatus)> {
let servers = self.servers.read().await;
servers
.iter()
.map(|(name, client)| (name.clone(), client.status()))
.collect()
}
pub async fn wait_for_initial_connections(&self, timeout: Duration) {
let _ = tokio::time::timeout(timeout, self.initial_ready.notified()).await;
}
pub fn share(&self) -> Arc<Self> {
Arc::new(Self {
servers: self.servers.clone(),
server_timeouts_ms: self.server_timeouts_ms.clone(),
connect_events: self.connect_events.clone(),
initial_ready: self.initial_ready.clone(),
telemetry: self.telemetry.clone(),
})
}
}
/// Buckets a free-form MCP connection error message into a coarse `McpErrorKind` for
/// telemetry.
///
/// Matching is case-insensitive and checked in priority order:
/// 1. network
/// 2. auth
/// 3. timeout
/// 4. execution
/// 5. server
/// 6. other
///
/// The timeout check runs before the broad "path"/"spawn"/"not found" heuristics. Those
/// substrings also occur inside timeout messages, for example one mentioning a spawned
/// process or a URL path.
fn classify_mcp_error(error: &str) -> McpErrorKind {
    let e = error.to_lowercase();
    let has = |needle: &str| e.contains(needle);
    if has("connection refused") || has("dns") || has("network") {
        McpErrorKind::NetworkError
    } else if contains_status_code(&e, "401")
        || contains_status_code(&e, "403")
        || has("unauthorized")
        || has("oauth")
    {
        McpErrorKind::AuthError
    } else if has("timeout") || has("timed out") {
        McpErrorKind::Timeout
    } else if has("not found") || has("no such") || has("path") || has("spawn") {
        McpErrorKind::ExecutionFailed
    } else if has("server") || has("-326") || has("mcp error") {
        McpErrorKind::ServerError
    } else {
        McpErrorKind::Other
    }
}

/// Returns true when `code` occurs in `haystack` as a standalone number.
///
/// An occurrence inside a longer run of digits does not count, such as a port
/// (`localhost:4013`) or a duration (`14030ms`).
fn contains_status_code(haystack: &str, code: &str) -> bool {
    haystack.match_indices(code).any(|(start, matched)| {
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        !matches!(before, Some(c) if c.is_ascii_digit())
            && !matches!(after, Some(c) if c.is_ascii_digit())
    })
}
impl McpServerConfig {
    /// Request timeout for this server, in milliseconds.
    ///
    /// Reads `timeout_ms` from whichever transport is configured and falls back to
    /// 30 000 ms when it is unset.
    fn timeout_ms(&self) -> u64 {
        use super::config::McpTransportConfig as Transport;

        let configured = match &self.config {
            Transport::Stdio { timeout_ms, .. } => *timeout_ms,
            Transport::Http { timeout_ms, .. } => *timeout_ms,
        };
        configured.unwrap_or(30_000)
    }
}
impl Default for McpRegistry {
fn default() -> Self {
Self::new()
}
}