* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::{
ffi::c_void,
mem, process,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
LazyLock,
},
thread,
time::Duration,
};
use anyhow::{bail, Context};
use parking_lot::Mutex;
use tracing::debug;
use xgpu_common::{
api_name::ApiFuncName,
ipc::{
framer::Framer,
message::{Argument, ArgumentFlag, Request},
peer::Server,
transport::Transport,
},
};
use crate::fault_guard::{bootstrapx, journal, recovery::NCCLAPI_COMMIDX_MAP, virt};
/// Number of `Agent::invoke_api` calls that have registered themselves as in flight.
static ACTIVE_REQUEST_COUNT: AtomicUsize = AtomicUsize::new(0);
/// While set, `Agent::invoke_api` blocks before issuing new requests.
static RECOVERY_FLAG: AtomicBool = AtomicBool::new(false);
/// Whether a recovery has been triggered; accessed through
/// `set_recovery_triggered` / `get_recovery_triggered`.
static RECOVERY_TRIGGERED: AtomicBool = AtomicBool::new(false);

/// Process-wide agent, built lazily on first access.
///
/// Initialization panics if `Agent::new` returns an error.
static GLOBAL_AGENT: LazyLock<Agent> = LazyLock::new(|| {
    let created = Agent::new();
    created.expect("FATAL: Failed to initialize agent")
});

/// Client-side agent owning the IPC server; access to the server is
/// serialized through a mutex.
#[derive(Debug)]
pub struct Agent {
    server: Mutex<Server>,
}
impl Agent {
    /// Returns the process-wide agent, creating it on first use.
    ///
    /// # Panics
    /// The first call panics if [`Agent::new`] fails.
    #[inline]
    pub fn get_instance() -> &'static Self {
        &GLOBAL_AGENT
    }

    /// Creates an agent backed by a freshly created IPC server.
    ///
    /// The server address is the current process id. Both the transport and the
    /// framer are sized with a 512 MiB buffer.
    ///
    /// # Errors
    /// Returns an error if the IPC server cannot be created.
    pub fn new() -> anyhow::Result<Self> {
        const BUFFER_SIZE: usize = 512 * 1024 * 1024;
        let addr = process::id().to_string();
        let transport = Transport::new(BUFFER_SIZE);
        let framer = Framer::new(BUFFER_SIZE);
        let server = Server::create(&transport, framer, &addr)
            .with_context(|| format!("Failed to create ipc server on '{}'", addr))?;
        debug!("{:#?}", server);
        Ok(Self {
            server: Mutex::new(server),
        })
    }

    /// Forwards an API call to the IPC server and returns its typed return value.
    ///
    /// Steps:
    /// 1. Wait (polling every millisecond) while a recovery is in progress.
    /// 2. Record the serialized request in the mini-batch journal.
    /// 3. Invoke the server while holding the server lock.
    /// 4. Check that the response's request/method ids match the request.
    /// 5. Copy the response's arguments back into `req`.
    ///
    /// While the call is in flight it is counted in `ACTIVE_REQUEST_COUNT`. The
    /// count is released on every exit path, including errors.
    ///
    /// # Errors
    /// Returns an error in these cases:
    /// - journaling fails;
    /// - the IPC call fails;
    /// - the response ids do not match the request;
    /// - the return value cannot be downcast to `T`;
    /// - updating the request from the response fails.
    pub fn invoke_api<T: Copy + 'static>(&self, mut req: Request) -> anyhow::Result<T> {
        /// Decrements `ACTIVE_REQUEST_COUNT` when dropped, so early returns via
        /// `?`/`bail!` cannot leak an in-flight count.
        struct ActiveRequest;
        impl Drop for ActiveRequest {
            fn drop(&mut self) {
                ACTIVE_REQUEST_COUNT.fetch_sub(1, Ordering::SeqCst);
            }
        }

        // Register as active *before* checking the recovery flag, and back off while
        // it is set. This closes the window where the flag was read as clear but the
        // count was not yet incremented. SeqCst is used because this is a flag/counter
        // handshake across two different atomics.
        // NOTE(review): assumes the recovery side sets the flag before reading the count.
        let _active = loop {
            ACTIVE_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst);
            let active = ActiveRequest;
            if !RECOVERY_FLAG.load(Ordering::SeqCst) {
                break active;
            }
            drop(active);
            while RECOVERY_FLAG.load(Ordering::Acquire) {
                thread::sleep(Duration::from_millis(1));
            }
        };

        debug!(
            "[<-agent] invoke_api, request_id: {}, method_id: {}, arg_num: {}",
            req.request_id(),
            req.method_id(),
            req.args().len()
        );
        let req_bytes = journal::serialize_request(&req)
            .with_context(|| format!("Failed to serialize request {}", req.request_id()))?;
        journal::add_minibatch_request(req_bytes)
            .with_context(|| format!("Failed to record request {}", req.request_id()))?;

        let server = self.server.lock();
        let resp = server
            .invoke(&req)
            .with_context(|| format!("Failed to invoke request {}", req.request_id()))?;
        if req.request_id() != resp.request_id() {
            bail!("Request id mismatch");
        }
        if req.method_id() != resp.method_id() {
            bail!("Method id mismatch");
        }

        let ret_arg = *resp.ret_value();
        let ret_value = ret_arg.downcast::<T>().with_context(|| {
            format!(
                "Failed to downcast return value of request {}",
                req.request_id()
            )
        })?;
        req.update_from(&resp)
            .with_context(|| format!("Failed to update request {} parameters", req.request_id()))?;
        debug!("[->agent] get response ok, update OUT args ok");
        Ok(ret_value)
    }

    /// Re-issues a journaled request, translating virtual handles to current ones.
    ///
    /// Before the call, every argument flagged with both `ARG_IN` and `ARG_VIRT` is
    /// rewritten:
    /// - For `cudaLaunchKernel`, arguments at index >= 5 are treated as a byte
    ///   buffer. Each complete native-endian `usize` word that is a valid virtual
    ///   handle is replaced by `virt::handle_map` of it. Other words, and trailing
    ///   bytes shorter than a word, are left unchanged.
    /// - Arguments listed in `NCCLAPI_COMMIDX_MAP` for this method are replaced by
    ///   `bootstrapx::get_comm()`.
    /// - Any other such argument is read as a `usize` virtual handle and mapped with
    ///   `virt::handle_map`. For `cudaPointerGetAttributes`, a failed mapping leaves
    ///   the argument as-is.
    ///
    /// After the call (see [`Agent::send_request`]), every argument flagged with both
    /// `ARG_OUT` and `ARG_VIRT` is registered with `virt::handle_insert`. It is filed
    /// under the virtual handle recorded for this request id. The size comes from
    /// argument 1 for `cudaMalloc`/`cudaHostAlloc` and is 0 otherwise.
    ///
    /// # Errors
    /// Returns an error if an argument cannot be downcast or updated, a virtual
    /// handle cannot be mapped or inserted, or the IPC call fails.
    pub fn replay_api(&self, req: &mut Request) -> anyhow::Result<()> {
        const WORD: usize = mem::size_of::<usize>();
        let req_id = req.request_id();
        let method_id = req.method_id();

        for (i, arg) in req.args_mut().iter_mut().enumerate() {
            if !arg
                .flag()
                .contains(ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT)
            {
                continue;
            }

            if method_id == ApiFuncName::Cudalaunchkernel as u64 && i >= 5 {
                // SAFETY: NOTE(review): relies on this argument holding a plain byte
                // buffer; the contract of `downcast_mut_slice` is defined in xgpu_common.
                let param_data = unsafe {
                    arg.downcast_mut_slice::<u8>()
                        .context("Failed to downcast kernel params argument")
                }?;
                // Only whole words are candidates; a short tail is left untouched.
                for chunk in param_data.chunks_exact_mut(WORD) {
                    let mut word = [0u8; WORD];
                    word.copy_from_slice(chunk);
                    let key = usize::from_ne_bytes(word);
                    let value = if virt::is_vhandle_valid(key as *mut c_void) {
                        match virt::handle_map(key as *mut c_void) {
                            Ok(ph) => ph as usize,
                            Err(_) => bail!(
                                "Failed to map virtual handle: req_id {}, arg_idx: {}",
                                req_id,
                                i
                            ),
                        }
                    } else {
                        key
                    };
                    chunk.copy_from_slice(&value.to_ne_bytes());
                }
                let new_arg = Argument::from_slice(
                    param_data,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                );
                arg.update_from(&new_arg)
                    .context("Failed to update kernel params argument from virtual handles")?;
            } else if NCCLAPI_COMMIDX_MAP
                .iter()
                .any(|(mid, comm_idx)| *mid == method_id && *comm_idx == i)
            {
                let new_comm = bootstrapx::get_comm();
                let mut new_comm_usize = new_comm as usize;
                let new_arg = Argument::from_mut(
                    &mut new_comm_usize,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                );
                arg.update_from(&new_arg)
                    .context("Failed to update ncclcomm argument from new communicator")?;
            } else {
                let vh = arg
                    .downcast::<usize>()
                    .context("Failed to downcast argument to usize")?;
                let ph = match virt::handle_map(vh as *mut c_void) {
                    Ok(p) => p,
                    // Mapping failures are tolerated for cudaPointerGetAttributes:
                    // the argument is left unchanged.
                    Err(_) if method_id == ApiFuncName::Cudapointergetattributes as u64 => {
                        continue
                    }
                    Err(e) => {
                        eprintln!("Error mapping virtual handle: {}", e);
                        bail!(
                            "Failed to map virtual handle: req_id: {}, method_id: {}, arg_idx: {}",
                            req_id,
                            method_id,
                            i
                        )
                    }
                };
                let mut ph_usize = ph as usize;
                let new_arg = Argument::from_mut(
                    &mut ph_usize,
                    ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
                );
                arg.update_from(&new_arg)
                    .context("Failed to update argument from virtual handle")?;
            }
        }

        self.send_request(req)?;

        for (i, arg) in req.args().iter().enumerate() {
            if !arg
                .flag()
                .contains(ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT)
            {
                continue;
            }
            let vh = match virt::req_id_vhandle_map(req.request_id()) {
                Ok(vh) => vh,
                Err(_) => bail!(
                    "Failed to map request_id to virtual handle: req_id {}",
                    req_id
                ),
            };
            // SAFETY: NOTE(review): relies on the OUT argument storing a pointer-sized
            // handle; the contract of `downcast_mut` is defined in xgpu_common.
            let ph = unsafe { arg.downcast_mut::<usize>().unwrap() } as *mut usize
                as *mut *mut c_void;
            let size = if req.method_id() == ApiFuncName::Cudamalloc as u64
                || req.method_id() == ApiFuncName::Cudahostalloc as u64
            {
                *req.args()[1].downcast_ref::<usize>().unwrap()
            } else {
                0
            };
            if virt::handle_insert(vh, unsafe { *ph }, size).is_err() {
                bail!(
                    "Failed to insert virtual handle: req_id: {}, method_id: {}, arg_idx: {}",
                    req_id,
                    method_id,
                    i
                );
            }
        }
        Ok(())
    }

    /// Sends `req` to the server and copies the response back into `req`.
    ///
    /// Unlike [`Agent::invoke_api`], this does not:
    /// - wait for recovery to finish;
    /// - journal the request;
    /// - touch the active-request count.
    ///
    /// # Errors
    /// Returns an error if the IPC call or the response update fails.
    pub fn send_request(&self, req: &mut Request) -> anyhow::Result<()> {
        let server = self.server.lock();
        let resp = server.invoke(req).context("server_invoke failed")?;
        req.update_from(&resp).context("update should succeed")?;
        Ok(())
    }
}
/// Raises or clears the recovery flag that gates new `Agent::invoke_api` calls.
///
/// The store is `SeqCst` because the flag and `ACTIVE_REQUEST_COUNT` form a
/// handshake across two different atomics. With only `Release`/`Acquire`, the
/// following interleaving is allowed:
/// - a store to the flag followed by a load of the count (the recovery side);
/// - an increment of the count followed by a load of the flag (the request side).
/// In that interleaving, both sides can observe the other's old value.
#[inline]
pub fn set_recovery_flag(value: bool) {
    RECOVERY_FLAG.store(value, Ordering::SeqCst);
}
/// Records whether a recovery has been triggered (published with `Release`).
#[inline]
pub fn set_recovery_triggered(value: bool) {
    let triggered = &RECOVERY_TRIGGERED;
    triggered.store(value, Ordering::Release)
}
/// Reports whether a recovery has been triggered (read with `Acquire`).
#[inline]
pub fn get_recovery_triggered() -> bool {
    let triggered = &RECOVERY_TRIGGERED;
    triggered.load(Ordering::Acquire)
}
/// Returns how many `Agent::invoke_api` calls are currently registered as active.
///
/// The load is `SeqCst` so it is totally ordered with the `SeqCst` store in
/// `set_recovery_flag`. A recovery routine that raises the flag and then polls this
/// count cannot read a stale zero while a request is passing the flag check.
#[inline]
pub fn get_active_request_count() -> usize {
    ACTIVE_REQUEST_COUNT.load(Ordering::SeqCst)
}