* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::cmp::Ordering;
use std::ffi::{c_int, c_void};
use std::fs;
use std::thread::sleep;
use std::time::Duration;
use cudax::nccl;
use smallvec::smallvec;
use tracing::debug;
use xgpu_common::ipc::message::{ArgumentFlag, Request};
use xgpu_common::{api_name::ApiFuncName, ipc::message::Argument};
use crate::agent::Agent;
use crate::fault_guard::bootstrapx::get_rank_count;
use crate::fault_guard::{bootstrapx, journal, virt};
/// APIs that `recovery_failed` does not replay during its first replay stage.
///
/// A journaled request whose API appears in this list is still logged, but it
/// is not handed to `Agent::replay_api`. The second stage of `recovery_failed`
/// (minibatch requests) does not consult this list.
const RECOVERY_FAILED_SKIP_APIS: &[ApiFuncName] = &[
ApiFuncName::Cudasetdevice,
ApiFuncName::Ncclcomminitrankconfig,
ApiFuncName::Ncclbcast,
ApiFuncName::Ncclallreduce,
ApiFuncName::Ncclallgather,
ApiFuncName::Cudaeventquery,
ApiFuncName::Cudeviceprimaryctxgetstate,
ApiFuncName::Loaddynlibrary,
];
/// Lookup table pairing an NCCL API id (an `ApiFuncName` cast to `u64`) with a
/// `usize` index.
///
/// NOTE(review): judging by the name, the index is presumably the position of
/// the communicator among the call's arguments (the values line up with the
/// NCCL C signatures, e.g. `comm` is the sixth parameter of `ncclAllReduce`).
/// The table is not consumed in this part of the file, so confirm against the
/// code that reads it.
pub const NCCLAPI_COMMIDX_MAP: &[(u64, usize)] = &[
(ApiFuncName::Ncclcommdestroy as u64, 0),
(ApiFuncName::Ncclcommabort as u64, 0),
(ApiFuncName::Ncclcommsplit as u64, 0),
(ApiFuncName::Ncclcommgetasyncerror as u64, 0),
(ApiFuncName::Ncclbcast as u64, 4),
(ApiFuncName::Ncclallreduce as u64, 5),
(ApiFuncName::Ncclallgather as u64, 4),
(ApiFuncName::Ncclsend as u64, 4),
(ApiFuncName::Ncclrecv as u64, 4),
(ApiFuncName::Ncclreduce as u64, 6),
(ApiFuncName::Ncclcommfinalize as u64, 0),
];
/// Blocks until every rank has announced itself at the named synchronization point.
///
/// The caller publishes a marker file `/tmp/rank_<rank>_<comment>` and then
/// polls, every 100 ms, until the marker of each other rank in
/// `0..get_rank_count()` exists.
///
/// * `rank` - rank id of the caller; its own marker is written but not polled.
/// * `comment` - label of the synchronization point, used as the file-name suffix.
///
/// Caveats visible in the implementation: a failure to write the caller's own
/// marker is ignored, markers are never removed, and there is no timeout, so
/// the call returns only once all peer markers are visible.
pub fn barrier(rank: i32, comment: &str) {
    let marker_of = |r: i32| format!("/tmp/rank_{}_{}", r, comment);

    // Announce this rank; the write result is discarded.
    fs::write(marker_of(rank), b"ready").ok();

    loop {
        // The rank count is re-read on every polling round; the scan stops at
        // the first peer whose marker is still missing.
        let peers_ready = (0..get_rank_count() as i32)
            .filter(|&peer| peer != rank)
            .all(|peer| fs::metadata(marker_of(peer)).is_ok());
        if peers_ready {
            return;
        }
        sleep(Duration::from_millis(100));
    }
}
/// Rebuilds this process's NCCL communicator after rank `excluded_rank` failed.
///
/// The steps run in a fixed order, with a file-based `barrier` after the
/// shrink, add and setup steps:
///
/// 1. The rank equal to `excluded_rank` sends a `Cudadevicereset` request.
///    Every other rank sends an `Ncclcommshrink` request carrying the current
///    (mapped) communicator, the excluded rank list and an output slot for the
///    new communicator.
/// 2. Global rank ids are renumbered: ranks above `excluded_rank` move down by
///    one, and the excluded rank takes the last id (`get_rank_count() - 1`).
/// 3. The rank that now owns the last id sends `Ncclcomminitnewrank`; all
///    other ranks send `Ncclcommaddnewrank`.
/// 4. Every rank sends `Ncclcommsetupnewrank`.
/// 5. The new communicator is registered through `virt::handle_insert` and
///    stored with `bootstrapx::set_comm`.
///
/// # Errors
/// Returns the first error raised by the handle lookup, by any agent request,
/// or by the handle insertion. An early return skips the remaining barriers.
pub fn recovery_nccl(excluded_rank: i32) -> anyhow::Result<()> {
    let initial_rank = bootstrapx::get_global_rank();
    debug!("[Recovery] Rank {} starts nccl recovery...", initial_rank);
    debug!(
        "[Recovery] Rank {} exclude failed rank {}",
        initial_rank, excluded_rank
    );

    // Resolve the stored communicator handle through the handle map.
    let old_comm =
        virt::handle_map(bootstrapx::get_comm() as *mut c_void)? as nccl::ncclComm_t;
    let mut new_comm = nccl::ncclComm_t::default();

    if initial_rank != excluded_rank {
        // Surviving rank: request a communicator without the failed rank.
        let mut ranks_to_drop = [excluded_rank];
        let ranks_to_drop_len = ranks_to_drop.len() as i32;
        let shrink_flags = nccl::NCCL_SHRINK_DEFAULT as i32;
        let mut shrink_req = Request::with_args(
            ApiFuncName::Ncclcommshrink as u64,
            smallvec![
                Argument::from_ref(&old_comm, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
                Argument::from_mut_slice(&mut ranks_to_drop, ArgumentFlag::ARG_IN),
                Argument::from_ref(&ranks_to_drop_len, ArgumentFlag::ARG_IN),
                Argument::from_mut(
                    &mut new_comm,
                    ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
                ),
                Argument::empty(),
                Argument::from_ref(&shrink_flags, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance().send_request(&mut shrink_req)?;
    } else {
        // Failed rank: only a device reset is requested at this point.
        let mut device_reset_req =
            Request::with_args(ApiFuncName::Cudadevicereset as u64, smallvec![]);
        Agent::get_instance().send_request(&mut device_reset_req)?;
    }
    // This barrier still uses the rank id from before renumbering.
    barrier(initial_rank, "shrink");

    // The id handed to the excluded rank is the last one in the rank range.
    let new_rank = get_rank_count() as i32 - 1;
    let rank = match initial_rank.cmp(&excluded_rank) {
        Ordering::Less => initial_rank,
        Ordering::Greater => {
            let shifted = initial_rank - 1;
            debug!(
                "[Recovery] Rank {} updating rank id to {}",
                initial_rank, shifted
            );
            bootstrapx::set_global_rank(shifted);
            shifted
        }
        Ordering::Equal => {
            debug!(
                "[Recovery] Rank {} will be replaced by a new id {}",
                initial_rank, new_rank
            );
            bootstrapx::set_global_rank(new_rank);
            new_rank
        }
    };

    debug!("[Recovery] Rank {} adding new rank {}", rank, new_rank);
    if rank != new_rank {
        // Existing members pass the communicator they obtained from the shrink step.
        let mut add_new_req = Request::with_args(
            ApiFuncName::Ncclcommaddnewrank as u64,
            smallvec![Argument::from_ref(
                &new_comm,
                ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
            )],
        );
        Agent::get_instance().send_request(&mut add_new_req)?;
    } else {
        // The re-admitted rank receives its communicator as an output argument.
        let rank_count = bootstrapx::get_rank_count() as c_int;
        let mut init_new_req = Request::with_args(
            ApiFuncName::Ncclcomminitnewrank as u64,
            smallvec![
                Argument::from_mut(
                    &mut new_comm,
                    ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
                ),
                Argument::from_ref(&rank_count, ArgumentFlag::ARG_IN),
            ],
        );
        Agent::get_instance().send_request(&mut init_new_req)?;
    }
    barrier(rank, "add");

    let mut setup_new_req = Request::with_args(
        ApiFuncName::Ncclcommsetupnewrank as u64,
        smallvec![Argument::from_ref(
            &new_comm,
            ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
        )],
    );
    Agent::get_instance().send_request(&mut setup_new_req)?;
    barrier(rank, "setup");

    debug!(
        "[Recovery] Rank {}, old comm 0x{:x}, new comm 0x{:x}",
        rank, old_comm, new_comm
    );
    // Register the new communicator against the old one, then publish it.
    virt::handle_insert(old_comm as *mut c_void, new_comm as *mut c_void, 0)?;
    bootstrapx::set_comm(new_comm);

    debug!(
        "[Recovery] Rank {} nccl recovery completed successfully",
        rank
    );
    Ok(())
}
/// Replays journaled requests in two stages (per its log message, on the failed rank).
///
/// * Stage 1 drains `journal::load_next_request`. Each request is logged, and
///   it is replayed through the agent unless its API is listed in
///   `RECOVERY_FAILED_SKIP_APIS`.
/// * Stage 2 drains `journal::fetch_minibatch_request`. Each request is logged
///   and always replayed.
///
/// # Errors
/// Stops at, and returns, the first error from reading the journal, from
/// deserializing a request, from converting its method id into an
/// `ApiFuncName`, or from the replay itself.
pub fn recovery_failed() -> anyhow::Result<()> {
    let rank = bootstrapx::get_global_rank();
    debug!("[Recovery] Rank {} starts replaying (failed rank)", rank);

    // Stage 1: journaled requests, minus the skip list.
    while let Some(mut raw) = journal::load_next_request()? {
        let mut request: Request = journal::deserialize_request(&mut raw)?;
        let request_id = request.request_id();
        let api = ApiFuncName::try_from(request.method_id() as i32)?;
        debug!(
            "[Recovery] Rank {} Stage 1, replaying api: request_id={}, api={:?}",
            rank, request_id, api
        );
        if RECOVERY_FAILED_SKIP_APIS.contains(&api) {
            continue;
        }
        Agent::get_instance().replay_api(&mut request)?;
    }

    // Stage 2: minibatch requests, replayed unconditionally.
    while let Some(mut raw) = journal::fetch_minibatch_request()? {
        let mut request: Request = journal::deserialize_request(&mut raw)?;
        let request_id = request.request_id();
        let api = ApiFuncName::try_from(request.method_id() as i32)?;
        debug!(
            "[Recovery] Rank {} Stage 2, replaying api: request_id={}, api={:?}",
            rank, request_id, api
        );
        Agent::get_instance().replay_api(&mut request)?;
    }

    Ok(())
}