use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex as StdMutex, OnceLock};
use std::time::Duration;
use async_trait::async_trait;
use atomcode_core::config::Config;
use atomcode_core::conversation::Conversation;
use atomcode_core::live::{LiveEvent, TurnExecutor, TurnState, UserInput};
use atomcode_core::conversation::message::ImagePart;
use atomcode_core::mcp::{register_mcp_tools, McpRegistry};
use atomcode_core::provider;
use atomcode_core::tool::diagnostics::DiagnosticsTool;
use atomcode_core::tool::{ToolContext, ToolRegistry};
use atomcode_core::lsp::manager::build_lsp_manager;
use atomcode_core::turn::event::{TurnEvent, TurnResult};
use atomcode_core::tool::PermissionDecision;
use atomcode_core::turn::permission::{
ApprovalRequest, AutoPermissionDecider, AutoPermissionMode,
InteractivePermissionDecider, PermissionDecider,
};
use atomcode_core::turn::runner::TurnRunner;
use atomcode_telemetry::Telemetry;
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use crate::CachedMcpRegistry;
/// The single process-wide live session; populated on first use by `ensure_live_session_global`.
static LIVE: StdMutex<Option<Arc<atomcode_core::live::LiveSession>>> = StdMutex::new(None);
/// String form of the live session's id, recorded when that session is created.
static LIVE_SESSION_ID: StdMutex<Option<String>> = StdMutex::new(None);
/// Live provider override. When `None`, turns fall back to the executor's own `provider_name`
/// and then to the config's `default_provider`.
static LIVE_PROVIDER: StdMutex<Option<String>> = StdMutex::new(None);
/// Applies an optional provider override taken from a request.
///
/// `None` leaves whatever override is already installed untouched.
fn set_live_provider(provider: Option<String>) {
    let Some(name) = provider else {
        return;
    };
    live_set_provider(name);
}
/// Installs `provider` as the live provider override.
///
/// If a live session already exists, it is notified of the change as well.
pub fn live_set_provider(provider: String) {
    {
        // Scope the guard so LIVE_PROVIDER is released before LIVE is locked below.
        let mut slot = LIVE_PROVIDER.lock().unwrap();
        *slot = Some(provider.clone());
    }
    if let Some(session) = current_live_session() {
        session.notify_provider_changed(provider);
    }
}
/// Name of the provider live turns will use.
///
/// Returns the live override when one is set. Otherwise returns the config's default provider,
/// or an empty string if the config cannot be loaded.
fn live_current_provider() -> String {
    let overridden = LIVE_PROVIDER.lock().unwrap().clone();
    match overridden {
        Some(provider) => provider,
        None => Config::load(&Config::default_path())
            .map(|cfg| cfg.default_provider)
            .unwrap_or_default(),
    }
}
static LIVE_MCP_CACHE: OnceLock<Arc<tokio::sync::RwLock<std::collections::HashMap<std::path::PathBuf, crate::CachedMcpRegistry>>>> = OnceLock::new();
fn live_mcp_cache() -> Arc<tokio::sync::RwLock<std::collections::HashMap<std::path::PathBuf, crate::CachedMcpRegistry>>> {
LIVE_MCP_CACHE
.get_or_init(|| Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())))
.clone()
}
/// Returns the live session if one has been created, without creating it.
pub fn current_live_session() -> Option<Arc<atomcode_core::live::LiveSession>> {
    let slot = LIVE.lock().unwrap();
    slot.as_ref().map(Arc::clone)
}
pub fn ensure_live_session(
working_dir: std::path::PathBuf,
telemetry: Arc<atomcode_telemetry::Telemetry>,
) -> Arc<atomcode_core::live::LiveSession> {
ensure_live_session_global(working_dir, live_mcp_cache(), telemetry, Vec::new(), None)
}
/// Returns the live session, creating it from `initial` messages and `session_id` if needed.
///
/// The new session uses the process-wide MCP cache. When a live session already exists it is
/// returned unchanged and `initial` / `session_id` are ignored.
pub fn ensure_live_session_seeded(
    working_dir: std::path::PathBuf,
    telemetry: Arc<atomcode_telemetry::Telemetry>,
    initial: Vec<atomcode_core::conversation::message::Message>,
    session_id: Option<atomcode_core::session::SessionId>,
) -> Arc<atomcode_core::live::LiveSession> {
    let shared_cache = live_mcp_cache();
    ensure_live_session_global(working_dir, shared_cache, telemetry, initial, session_id)
}
pub(crate) fn ensure_live_session_global(
working_dir: std::path::PathBuf,
mcp_cache: Arc<tokio::sync::RwLock<std::collections::HashMap<std::path::PathBuf, crate::CachedMcpRegistry>>>,
telemetry: Arc<atomcode_telemetry::Telemetry>,
initial: Vec<atomcode_core::conversation::message::Message>,
session_id: Option<atomcode_core::session::SessionId>,
) -> Arc<atomcode_core::live::LiveSession> {
let mut g = LIVE.lock().unwrap();
if let Some(s) = g.as_ref() {
return s.clone();
}
let session_id = session_id.unwrap_or_else(atomcode_core::session::SessionId::new);
*LIVE_SESSION_ID.lock().unwrap() = Some(session_id.to_string());
let executor: Arc<dyn atomcode_core::live::TurnExecutor> = Arc::new(DaemonTurnExecutor {
working_dir,
provider_name: None,
mcp_cache,
telemetry,
auto_approve: false,
session_id,
});
let session = atomcode_core::live::LiveSession::new(executor, initial);
*g = Some(session.clone());
session
}
/// String id of the live session, or `None` if no session has been created yet.
fn live_session_id() -> Option<String> {
    let slot = LIVE_SESSION_ID.lock().unwrap();
    slot.as_ref().cloned()
}
/// Everything `build_turn_parts` assembles for one live turn; consumed by
/// `DaemonTurnExecutor::run_turn` to build its `TurnRunner`.
pub(crate) struct TurnParts {
    /// LLM provider resolved from the requested provider name or the config default.
    pub provider: Arc<dyn atomcode_core::provider::LlmProvider>,
    /// Built-in, skill, MCP and (when an LSP manager exists) diagnostics tools for the turn.
    pub tools: Arc<ToolRegistry>,
    /// Tool execution context for the working directory, with the LSP manager attached.
    pub context: ToolContext,
    /// Configuration loaded from `Config::default_path()`.
    pub config: Config,
    /// Context builder derived from the resolved provider's configuration.
    pub ctx: Arc<dyn atomcode_core::ctx::CtxBuilder>,
    /// System prompt passed to every `TurnRunner::run` call of the turn.
    pub system_prompt: String,
}
/// Assembles everything a single live turn needs: provider, tool registry, tool context,
/// config, provider-specific context builder and system prompt.
///
/// `provider_name` overrides `config.default_provider` when given. Built-in tools can be
/// switched off with the comma-separated `ATOMCODE_DISABLE_TOOLS` environment variable
/// (e.g. `ATOMCODE_DISABLE_TOOLS=bash,web_fetch`). MCP registries are shared per working
/// directory through `mcp_cache`.
///
/// # Errors
///
/// Fails when the config cannot be loaded, when the resolved provider is not configured, or
/// when the provider cannot be constructed.
pub(crate) async fn build_turn_parts(
    working_dir: &Path,
    provider_name: Option<&str>,
    mcp_cache: &Arc<RwLock<HashMap<PathBuf, CachedMcpRegistry>>>,
    telemetry: Arc<Telemetry>,
) -> anyhow::Result<TurnParts> {
    use atomcode_core::tool::{
        bash::BashTool, edit::EditFileTool, glob::GlobTool, grep::GrepTool,
        list_dir::ListDirTool, read::ReadFileTool, search_replace::SearchReplaceTool,
        todo::TodoTool, web_fetch::WebFetchTool, web_search::WebSearchTool,
        write::WriteFileTool,
    };

    let config_path = Config::default_path();
    let config = Config::load(&config_path)?;
    let resolved_provider_name = provider_name
        .map(str::to_string)
        .unwrap_or_else(|| config.default_provider.clone());
    let provider_config = config
        .providers
        .get(&resolved_provider_name)
        .ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", resolved_provider_name))?;
    let provider = provider::create_provider(provider_config)?;
    let mut tool_context = ToolContext::with_telemetry(
        working_dir.to_path_buf(),
        "live",
        telemetry,
    );

    let mut tool_registry = ToolRegistry::new();
    let disabled_tools = disabled_tools_from_env();
    let enabled = |name: &str| !disabled_tools.contains(name);
    if enabled("read_file") {
        tool_registry.register_sync(Box::new(ReadFileTool));
    }
    if enabled("write_file") {
        tool_registry.register_sync(Box::new(WriteFileTool));
    }
    if enabled("edit_file") {
        tool_registry.register_sync(Box::new(EditFileTool));
    }
    if enabled("bash") {
        tool_registry.register_sync(Box::new(BashTool));
    }
    if enabled("grep") {
        tool_registry.register_sync(Box::new(GrepTool));
    }
    if enabled("glob") {
        tool_registry.register_sync(Box::new(GlobTool));
    }
    if enabled("list_directory") {
        tool_registry.register_sync(Box::new(ListDirTool));
    }
    if enabled("web_search") {
        tool_registry.register_sync(Box::new(WebSearchTool));
    }
    if enabled("web_fetch") {
        tool_registry.register_sync(Box::new(WebFetchTool));
    }
    if enabled("search_replace") {
        tool_registry.register_sync(Box::new(SearchReplaceTool));
    }
    if enabled("todo") {
        tool_registry.register_sync(Box::new(TodoTool::new()));
    }

    // `use_skill` is only exposed when at least one skill was discovered for this directory.
    let mut skill_registry = atomcode_core::skill::SkillRegistry::new();
    skill_registry.reload(working_dir);
    let has_skills = !skill_registry.is_empty();
    let skill_registry = Arc::new(std::sync::RwLock::new(skill_registry));
    if has_skills && enabled("use_skill") {
        tool_registry.register_sync(Box::new(atomcode_core::tool::use_skill::UseSkillTool {
            registry: skill_registry.clone(),
        }));
    }

    let working_dir_buf = working_dir.to_path_buf();
    let mcp_registry = acquire_mcp_registry(mcp_cache, working_dir_buf.clone()).await;
    let mcp_tools = mcp_registry.list_all_tools().await;
    if !mcp_tools.is_empty() {
        register_mcp_tools(&mut tool_registry, mcp_registry.clone(), mcp_tools);
    }

    let lsp_manager = build_lsp_manager(&config.lsp, working_dir);
    if lsp_manager.is_some() && enabled("diagnostics") {
        tool_registry.register_sync(Box::new(DiagnosticsTool));
    }
    tool_context.lsp = lsp_manager;

    // `provider_config` is guaranteed to exist here: an unknown provider already returned an
    // error above. So the context builder comes straight from it, with no placeholder config.
    let ctx = atomcode_core::ctx::for_provider(provider_config);
    let system_prompt =
        crate::build_api_system_prompt(&working_dir_buf, &config, provider_config, &skill_registry);
    Ok(TurnParts {
        provider: provider.into(),
        tools: Arc::new(tool_registry),
        context: tool_context,
        config,
        ctx,
        system_prompt,
    })
}

/// Parses `ATOMCODE_DISABLE_TOOLS` (comma-separated tool names) into a set.
///
/// Whitespace is trimmed and empty entries are ignored. An unset or non-UTF-8 variable yields
/// an empty set.
fn disabled_tools_from_env() -> std::collections::HashSet<String> {
    std::env::var("ATOMCODE_DISABLE_TOOLS")
        .ok()
        .map(|raw| {
            raw.split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(str::to_string)
                .collect()
        })
        .unwrap_or_default()
}

/// Returns the MCP registry cached for `working_dir`, creating and caching one on a miss.
///
/// A hit refreshes the entry's `last_used` stamp. On a miss the new registry gets up to five
/// seconds to establish its initial connections before it is cached. If the cache is full, the
/// least-recently-used entry is evicted first.
async fn acquire_mcp_registry(
    mcp_cache: &Arc<RwLock<HashMap<PathBuf, CachedMcpRegistry>>>,
    working_dir: PathBuf,
) -> Arc<McpRegistry> {
    {
        // Hit path: refresh the LRU stamp and reuse the registry under a single write lock.
        let mut cache = mcp_cache.write().await;
        if let Some(entry) = cache.get_mut(&working_dir) {
            entry.last_used = std::time::Instant::now();
            return Arc::clone(&entry.registry);
        }
    }

    // Connect without holding the lock, so other turns are not stalled behind MCP startup.
    let registry = Arc::new(McpRegistry::from_config_background(&working_dir));
    registry
        .wait_for_initial_connections(Duration::from_secs(5))
        .await;

    let mut cache = mcp_cache.write().await;
    // A concurrent turn may have cached this directory while we were connecting. In that case
    // the insert below just replaces its entry and the cache does not grow. Only a genuinely
    // new key may need room, so only then evict.
    if !cache.contains_key(&working_dir) && cache.len() >= crate::MCP_CACHE_MAX {
        let oldest = cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(key, _)| key.clone());
        if let Some(oldest) = oldest {
            cache.remove(&oldest);
        }
    }
    cache.insert(
        working_dir,
        CachedMcpRegistry {
            registry: Arc::clone(&registry),
            last_used: std::time::Instant::now(),
        },
    );
    registry
}
/// `TurnExecutor` backing the daemon's live session. It builds fresh `TurnParts` for every turn
/// and saves the conversation as a session after each turn.
pub(crate) struct DaemonTurnExecutor {
    /// Project directory used for turn parts, hooks, and session persistence.
    pub working_dir: PathBuf,
    /// Provider name used when no live override (`LIVE_PROVIDER`) is set. `None` defers to the
    /// config's default provider.
    pub provider_name: Option<String>,
    /// MCP registry cache shared across turns, keyed by working directory.
    pub mcp_cache: Arc<RwLock<HashMap<PathBuf, CachedMcpRegistry>>>,
    /// Telemetry handle given to each turn's `ToolContext`.
    pub telemetry: Arc<Telemetry>,
    /// When true, tool calls are approved automatically (bypass-all); otherwise interactively.
    pub auto_approve: bool,
    /// Id under which the conversation is saved after each turn.
    pub session_id: atomcode_core::session::SessionId,
}
#[async_trait]
impl TurnExecutor for DaemonTurnExecutor {
    /// Replaces attached images with caption text before the input reaches the conversation.
    ///
    /// Text-only input is returned unchanged. Otherwise `preprocess_live_caption` rewrites the
    /// text using the live provider override when set, falling back to `self.provider_name`.
    /// The images themselves are kept on the returned input.
    async fn preprocess_input(&self, input: UserInput) -> UserInput {
        if input.images.is_empty() {
            return input;
        }
        // Copy the override out first so the std mutex guard is gone before the `.await` below.
        let live_provider = LIVE_PROVIDER.lock().unwrap().clone();
        let provider_name = live_provider.as_deref().or(self.provider_name.as_deref());
        let text = preprocess_live_caption(&input.text, &input.images, provider_name).await;
        UserInput { text, images: input.images }
    }

    /// Runs one live turn against the shared conversation and persists the result.
    ///
    /// Steps:
    /// 1. Build fresh `TurnParts`. If that fails, emit a `TurnEvent::Error` and return.
    /// 2. Install a permission decider: bypass-all when `auto_approve`; otherwise an interactive
    ///    decider whose response sender is published through `approver`.
    /// 3. Relay every `TurnEvent` onto `events`.
    /// 4. Keep calling the runner while it reports `UsedTools`.
    /// 5. Save the conversation as session `self.session_id` under `self.working_dir`.
    async fn run_turn(
        &self,
        conv: &Arc<Mutex<Conversation>>,
        events: broadcast::Sender<LiveEvent>,
        approver: Arc<Mutex<Option<mpsc::UnboundedSender<PermissionDecision>>>>,
        cancel: CancellationToken,
    ) {
        // Same precedence as `preprocess_input`: live override first, then this executor's default.
        let live_provider = LIVE_PROVIDER.lock().unwrap().clone();
        let provider_name = live_provider.as_deref().or(self.provider_name.as_deref());
        let parts = match build_turn_parts(
            &self.working_dir,
            provider_name,
            &self.mcp_cache,
            self.telemetry.clone(),
        )
        .await
        {
            Ok(p) => p,
            Err(e) => {
                let _ = events.send(LiveEvent::Turn(TurnEvent::Error(format!("构造 turn 失败:{e}"))));
                return;
            }
        };
        // `_perm_req_keep` owns the approval-request receiver for the rest of the turn. Nothing
        // reads it here. Presumably requests reach clients as `TurnEvent::ApprovalRequested`, and
        // the receiver is kept alive only so the decider's sends do not fail. TODO confirm.
        let (permission, _perm_req_keep): (Box<dyn PermissionDecider>, Option<_>) =
            if self.auto_approve {
                (
                    Box::new(AutoPermissionDecider::new(AutoPermissionMode::BypassAll)),
                    None,
                )
            } else {
                let (perm_req_tx, perm_req_rx) =
                    tokio::sync::mpsc::unbounded_channel::<ApprovalRequest>();
                let (perm_resp_tx, perm_resp_rx) =
                    tokio::sync::mpsc::unbounded_channel::<PermissionDecision>();
                // Publish the response sender so approvals delivered to the session reach this
                // turn's decider.
                *approver.lock().await = Some(perm_resp_tx);
                // A fresh store per turn: remembered permissions do not outlive the turn.
                let perm_store = std::sync::Arc::new(std::sync::RwLock::new(
                    atomcode_core::tool::PermissionStore::new(),
                ));
                (
                    Box::new(InteractivePermissionDecider::new(
                        perm_req_tx,
                        perm_resp_rx,
                        perm_store,
                    )),
                    Some(perm_req_rx),
                )
            };
        let mut hook_engine = atomcode_core::hook::HookEngine::new();
        hook_engine.load_all(&self.working_dir);
        let mut runner = TurnRunner {
            provider: parts.provider,
            tools: parts.tools,
            context: parts.context,
            config: parts.config,
            ctx: parts.ctx,
            permission,
            recently_edited_files: Vec::new(),
            hook_engine: std::sync::Arc::new(hook_engine),
            loop_guard: Default::default(),
            current_turn_number: 0,
        };
        // Relay runner events to the broadcast channel. The task ends once every sender of
        // `turn_tx` has been dropped.
        let (turn_tx, mut turn_rx) = mpsc::unbounded_channel::<TurnEvent>();
        let ev2 = events.clone();
        let forward = tokio::spawn(async move {
            while let Some(te) = turn_rx.recv().await {
                let _ = ev2.send(LiveEvent::Turn(te));
            }
        });
        {
            let mut c = conv.lock().await;
            loop {
                let result = runner
                    .run(&mut c, &parts.system_prompt, &turn_tx, cancel.clone())
                    .await;
                match result {
                    // The model called tools; run another round so it can continue from the
                    // results (presumably appended to the conversation by the runner).
                    TurnResult::UsedTools { .. } => continue,
                    TurnResult::Responded { .. } | TurnResult::Cancelled => break,
                    TurnResult::Failed(e) => {
                        let _ = turn_tx.send(TurnEvent::Error(e));
                        break;
                    }
                }
            }
        }
        // Close the event channel and wait until every queued event has been relayed.
        drop(turn_tx);
        let _ = forward.await;

        // Snapshot the messages, then release the conversation lock before the synchronous
        // disk write. Other tasks waiting on the conversation are not held up by file I/O.
        let messages = conv.lock().await.messages.clone();
        {
            use atomcode_core::session::{Session, SessionManager};
            let mut session = Session::new(self.working_dir.clone());
            session.id = self.session_id.clone();
            session.messages = messages;
            session.auto_name_from_messages();
            session.touch();
            if let Err(e) = SessionManager::new(&self.working_dir).save(&session) {
                eprintln!("Warning: failed to save live session: {e}");
            }
        }
    }
}
use axum::{
extract::State,
response::{
sse::{Event, Sse, KeepAlive},
IntoResponse,
Json,
},
};
use futures::stream::StreamExt;
use serde::Serialize;
use crate::AppState;
/// JSON event pushed to live SSE clients, serialized with a `type` discriminator field.
///
/// Variant wire names are the snake_case variant names, except where an explicit `rename`
/// gives a shorter name.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum LiveWireEvent {
    /// Sent first on every stream: message history plus session/project identity and the
    /// active provider.
    Snapshot { messages: Vec<crate::MessageInfo>, session_id: String, project_hash: String, provider: String },
    /// The live provider changed.
    Provider { provider: String },
    /// A user message entered the live session.
    #[serde(rename = "user")]
    UserMessage { text: String, images: Vec<crate::ImageData> },
    /// Streamed chunk of model output text.
    #[serde(rename = "text")]
    TextDelta { content: String },
    /// Streamed chunk of model reasoning.
    #[serde(rename = "reasoning")]
    ReasoningDelta { content: String },
    /// A tool call started.
    ToolStart { id: String, name: String, arguments: String },
    /// Incremental output from a running tool.
    ToolOutput { chunk: String },
    /// A tool call finished.
    ToolResult { id: String, name: String, output: String, success: bool, duration_ms: u64 },
    /// Token usage report.
    Tokens { prompt: usize, completion: usize, total: usize },
    /// Whether a turn is currently running.
    State { running: bool },
    /// Error (and warning) messages for display.
    Error { message: String },
    /// A tool call is waiting for the user's permission decision.
    PermissionRequest { tool_name: String, reason: String, call_id: String, arguments: String },
}
fn to_wire(ev: LiveEvent) -> Option<LiveWireEvent> {
use atomcode_core::turn::event::TurnEvent as TE;
Some(match ev {
LiveEvent::UserMessage { text, images } => LiveWireEvent::UserMessage {
text,
images: images.into_iter().map(|i| crate::ImageData { media_type: i.media_type, data: i.data }).collect(),
},
LiveEvent::StateChanged(s) => LiveWireEvent::State { running: matches!(s, TurnState::Running) },
LiveEvent::ProviderChanged(p) => LiveWireEvent::Provider { provider: p },
LiveEvent::Turn(te) => match te {
TE::TextDelta(content) => LiveWireEvent::TextDelta { content },
TE::ReasoningDelta(content) => LiveWireEvent::ReasoningDelta { content },
TE::ToolCallStarted { id, name, arguments } => LiveWireEvent::ToolStart { id, name, arguments },
TE::ToolOutputChunk { call_id: _, chunk } => LiveWireEvent::ToolOutput { chunk },
TE::ToolCallResult { call_id, name, output, success, duration } =>
LiveWireEvent::ToolResult { id: call_id, name, output, success, duration_ms: duration.as_millis() as u64 },
TE::TokenUsage { prompt_tokens, completion_tokens, total_tokens, .. } =>
LiveWireEvent::Tokens { prompt: prompt_tokens, completion: completion_tokens, total: total_tokens },
TE::Error(message) => LiveWireEvent::Error { message },
TE::Warning(w) => LiveWireEvent::Error { message: format!("[warning] {w}") },
TE::ApprovalRequested { tool_name, reason, call, .. } =>
LiveWireEvent::PermissionRequest {
tool_name,
reason,
call_id: call.id,
arguments: call.arguments,
},
TE::ToolCallStreaming { .. } | TE::ToolBatchStarted { .. } | TE::ToolBatchCompleted { .. }
| TE::ContextStats { .. } | TE::WorkingDirChanged(_) => return None,
},
})
}
/// SSE endpoint streaming the live session to one client.
///
/// The client first receives a `snapshot` event (history, session id, project hash, provider),
/// then every subsequent live event in wire form. A `ping` keep-alive is sent every 15 seconds.
pub(crate) async fn live_stream(State(state): State<AppState>) -> impl IntoResponse {
    let working_dir = state.project.read().await.working_dir.clone();
    let project_hash = crate::hash_path(&working_dir);
    let session = ensure_live_session(working_dir, state.telemetry.clone());
    let (snapshot, mut live_rx) = session.join().await;
    let (wire_tx, wire_rx) = mpsc::unbounded_channel::<LiveWireEvent>();

    // Queue the snapshot before any live event so it is always the first thing a client sees.
    let snapshot_event = LiveWireEvent::Snapshot {
        messages: snapshot.iter().map(crate::MessageInfo::from).collect(),
        session_id: live_session_id().unwrap_or_default(),
        project_hash,
        provider: live_current_provider(),
    };
    let _ = wire_tx.send(snapshot_event);

    tokio::spawn(async move {
        loop {
            let event = match live_rx.recv().await {
                Ok(event) => event,
                // A lagging client simply skips the events it missed.
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => break,
            };
            let Some(wire) = to_wire(event) else {
                continue;
            };
            // The SSE stream (and its receiver) is gone once the client disconnects.
            if wire_tx.send(wire).is_err() {
                break;
            }
        }
    });

    let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(wire_rx).map(|wire| {
        let json = serde_json::to_string(&wire).unwrap_or_default();
        Ok::<_, std::convert::Infallible>(Event::default().data(json))
    });
    let keep_alive = KeepAlive::new()
        .interval(std::time::Duration::from_secs(15))
        .text("ping");
    Sse::new(stream).keep_alive(keep_alive)
}
/// Request body for `live_message`: a user message for the live session.
#[derive(serde::Deserialize)]
pub(crate) struct LiveMessageReq {
    /// User text for the turn.
    pub message: String,
    /// Attached images; empty when omitted.
    #[serde(default)]
    pub images: Vec<crate::ImageInput>,
    /// Optional provider override. When present it replaces the live provider before the input
    /// is sent; when absent the current override is kept.
    #[serde(default)]
    pub provider: Option<String>,
}
/// Turns image attachments into caption text appended to `message`.
///
/// Uses the vision preprocessor with the provider named by `provider_name`, or the config
/// default when `None`.
/// - Without images, or when the config/provider cannot be loaded, or when preprocessing is
///   skipped, `message` is returned unchanged.
/// - A recognised caption, or a failure marker, is appended after a blank line. If `message`
///   is blank, the caption/marker is returned on its own.
async fn preprocess_live_caption(
    message: &str,
    images: &[ImagePart],
    provider_name: Option<&str>,
) -> String {
    use atomcode_core::vision_preprocessor::{maybe_preprocess, PreprocessOutcome};
    if images.is_empty() {
        return message.to_string();
    }
    let Ok(config) = Config::load(&Config::default_path()) else {
        return message.to_string();
    };
    let name = match provider_name {
        Some(explicit) => explicit.to_string(),
        None => config.default_provider.clone(),
    };
    let Some(Ok(active)) = config.providers.get(&name).map(provider::create_provider) else {
        return message.to_string();
    };
    let note = match maybe_preprocess(&config, &*active, message, images).await {
        PreprocessOutcome::Skipped => return message.to_string(),
        PreprocessOutcome::Replaced { text, vl_key } => {
            format!("[图片内容(由 {vl_key} 识别)]\n{text}")
        }
        PreprocessOutcome::Failed { .. } => "[图片识别失败]".to_string(),
    };
    if message.trim().is_empty() {
        note
    } else {
        format!("{message}\n\n{note}")
    }
}
pub(crate) async fn live_message(State(state): State<AppState>, Json(req): Json<LiveMessageReq>) -> impl IntoResponse {
let working_dir = { state.project.read().await.working_dir.clone() };
set_live_provider(req.provider);
let session = ensure_live_session(working_dir, state.telemetry.clone());
let ok = session.send_input(UserInput {
text: req.message,
images: req.images.into_iter().map(|i| ImagePart { media_type: i.media_type, data: i.data }).collect(),
});
Json(serde_json::json!({ "accepted": ok }))
}
/// Request body for `live_provider`: the provider name to switch the live session to.
#[derive(serde::Deserialize)]
pub(crate) struct LiveProviderReq {
    /// Provider key as it appears in the config's `providers` table.
    pub provider: String,
}
pub(crate) async fn live_provider(
State(state): State<AppState>,
Json(req): Json<LiveProviderReq>,
) -> impl IntoResponse {
if let Ok(mut cfg) = Config::load(&Config::default_path()) {
if cfg.providers.contains_key(&req.provider) && cfg.default_provider != req.provider {
cfg.default_provider = req.provider.clone();
let _ = cfg.save(&Config::default_path());
}
}
let working_dir = { state.project.read().await.working_dir.clone() };
ensure_live_session(working_dir, state.telemetry.clone());
live_set_provider(req.provider);
Json(serde_json::json!({ "ok": true }))
}
/// Request body for `live_permission`: the user's answer to a pending tool-permission request.
#[derive(serde::Deserialize)]
pub(crate) struct LivePermissionReq {
    /// `"allow"` or `"always_allow"` approves the call; any other value denies it.
    /// NOTE(review): `always_allow` is currently handled exactly like `allow`; nothing in this
    /// module persists it beyond the current decision.
    pub decision: String,
}
pub(crate) async fn live_permission(
State(state): State<AppState>,
Json(req): Json<LivePermissionReq>,
) -> impl IntoResponse {
use atomcode_core::tool::PermissionDecision;
let decision = match req.decision.as_str() {
"allow" | "always_allow" => PermissionDecision::Allow,
_ => PermissionDecision::Deny,
};
let working_dir = { state.project.read().await.working_dir.clone() };
let ok = match current_live_session() {
Some(s) => s.approve(decision).await,
None => {
ensure_live_session(working_dir, state.telemetry.clone());
false
}
};
Json(serde_json::json!({ "accepted": ok }))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request's `provider` field is parsed and installed as the live override. A request
    /// that omits it deserializes to `None` and leaves the existing override in place.
    #[test]
    fn live_message_parses_provider_and_updates_override() {
        // `LIVE_PROVIDER` is process-global and shared by every test in this binary. Snapshot it
        // so the "openai" override set here does not leak into other tests.
        let previous = LIVE_PROVIDER.lock().unwrap().clone();

        let req: LiveMessageReq =
            serde_json::from_str(r#"{"message":"hi","provider":"openai"}"#).unwrap();
        assert_eq!(req.provider.as_deref(), Some("openai"));
        set_live_provider(req.provider);
        assert_eq!(LIVE_PROVIDER.lock().unwrap().as_deref(), Some("openai"));

        let req2: LiveMessageReq = serde_json::from_str(r#"{"message":"hi"}"#).unwrap();
        assert_eq!(req2.provider, None);
        set_live_provider(req2.provider);
        assert_eq!(LIVE_PROVIDER.lock().unwrap().as_deref(), Some("openai"));

        *LIVE_PROVIDER.lock().unwrap() = previous;
    }

    /// Without images the caption preprocessor returns the message verbatim, before any config
    /// is loaded.
    #[tokio::test]
    async fn preprocess_live_caption_is_passthrough_without_images() {
        let out = preprocess_live_caption("看下这个图片", &[], None).await;
        assert_eq!(out, "看下这个图片");
    }
}