* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::ffi::{c_char, c_int, c_void, OsStr};
use once_cell::sync::Lazy;
use tracing::{debug, trace};
use xgpu_common::sys::dynlib;
use crate::hook::dl::DlApi;
mod interceptor {
    use std::{
        collections::HashSet,
        ffi::{CStr, CString, OsString},
        os::unix::ffi::OsStrExt,
        path::{Path, PathBuf},
        ptr,
    };

    use indexmap::IndexMap;
    use parking_lot::RwLock;
    use thiserror::Error;

    use super::*;

    /// Errors that can occur while constructing a [`DlInterceptor`].
    #[derive(Debug, Error)]
    pub enum Error {
        /// `dladdr` succeeded but did not report a file name for this shared object.
        #[error("Failed to resolve current library path")]
        CurrentLibraryPathResolveError,
        /// An error reported by the `dynlib` wrapper.
        #[error(transparent)]
        DlError(#[from] dynlib::DlError),
    }

    /// Redirects `dlopen` calls for a configured set of libraries to the shared object that
    /// contains this code.
    ///
    /// For every redirected library, the real library is still opened once, and its handle is
    /// kept. This lets [`DlInterceptor::find_original_symbol`] resolve the original symbols.
    #[derive(Debug)]
    pub struct DlInterceptor {
        /// Path of the shared object containing this code, as reported by `dladdr`.
        self_path: CString,
        /// Normalized library names (e.g. `libcuda.so`) whose `dlopen` calls are redirected.
        hook_files: HashSet<OsString>,
        /// Handles of the real libraries opened for redirected calls. Keys are the path exactly
        /// as the caller passed it (not canonicalized); the map keeps first-opened order.
        hooked_libs: RwLock<IndexMap<PathBuf, dynlib::DlHandle>>,
    }

    impl DlInterceptor {
        /// Reduces a library path to the name used for hook matching.
        ///
        /// Takes the file-name component of `path` and truncates it right after the last `.so`
        /// that ends the name or is followed by `.`. For example,
        /// `/usr/lib/x86_64-linux-gnu/libcuda.so.1` becomes `libcuda.so`. Names without such an
        /// extension (e.g. `libcuda.sox`) are returned unchanged. If `path` has no file-name
        /// component, the whole path is used.
        #[inline]
        fn normalize_library_name(path: &Path) -> &OsStr {
            const LIB_EXT: &[u8] = b".so";

            let file_name = path.file_name().unwrap_or(path.as_os_str());
            let name_bytes = file_name.as_bytes();
            let ext_len = LIB_EXT.len();

            // Only accept `.so` as a real extension: it must end the name or be followed by a
            // version separator, so `libcuda.sox` is not mistaken for `libcuda.so`.
            let end = name_bytes
                .windows(ext_len)
                .enumerate()
                .rev()
                .find(|&(pos, window)| {
                    window == LIB_EXT
                        && matches!(name_bytes.get(pos + ext_len), None | Some(b'.'))
                })
                .map(|(pos, _)| pos + ext_len);

            match end {
                Some(end) => OsStr::from_bytes(&name_bytes[..end]),
                None => file_name,
            }
        }

        /// Creates an interceptor that redirects `dlopen` for the given library names.
        ///
        /// `file_names` are compared against normalized names (see `normalize_library_name`),
        /// so entries should look like `libcuda.so`.
        ///
        /// # Errors
        ///
        /// Returns an error when `dladdr` fails for this library, or when it reports no file
        /// name for it.
        pub fn new<I, S>(file_names: I) -> Result<Self, Error>
        where
            I: IntoIterator<Item = S>,
            S: AsRef<OsStr>,
        {
            // `dladdr` maps an address to the shared object containing it. A zero-sized static
            // occupies no bytes, so its address may coincide with another symbol or a section
            // boundary. A 1-byte static is guaranteed to lie inside this library's image.
            static LIB_MARKER: u8 = 0;

            let dl_info = dynlib::dladdr(ptr::from_ref(&LIB_MARKER).cast())?;
            let self_path = dl_info
                .file_name()
                .ok_or(Error::CurrentLibraryPathResolveError)?
                .to_owned();
            let hook_files = file_names
                .into_iter()
                .map(|s| s.as_ref().to_owned())
                .collect();

            Ok(Self {
                self_path,
                hook_files,
                hooked_libs: RwLock::new(IndexMap::new()),
            })
        }

        /// `dlopen` replacement.
        ///
        /// A null `filename`, names not in the hook set, and anything under
        /// `/usr/lib/wsl/drivers` are forwarded unchanged to the real `dlopen`. For a hooked
        /// name, the real library is first opened (once per requested path) with
        /// `DlOpenFlag::LAZY`, and its handle is retained. This shared object is then opened
        /// with the caller's `flag`, and that handle is returned.
        ///
        /// Returns null if the real library for a hooked name cannot be opened.
        pub fn dlopen(&self, filename: *const c_char, flag: c_int) -> *mut c_void {
            const WSL_DRIVER_PATH: &str = "/usr/lib/wsl/drivers";

            if filename.is_null() {
                // SAFETY: the caller's arguments are forwarded unchanged to the real `dlopen`,
                // which accepts a null `filename`.
                return unsafe { dynlib::ffi::dlopen(filename, flag) };
            }

            // SAFETY: `filename` is non-null (checked above). Per the `dlopen` contract, it
            // points to a NUL-terminated string that stays valid for the duration of this call.
            let filename = unsafe { CStr::from_ptr(filename) };
            let lib_path = Path::new(OsStr::from_bytes(filename.to_bytes()));
            let lib_name = Self::normalize_library_name(lib_path);
            debug!("filename:{:?}", filename);
            debug!("lib_path:{:?}", lib_path);
            debug!("lib_name:{:?}", lib_name);

            let is_wsl_driver = lib_path.starts_with(WSL_DRIVER_PATH);
            let filename = if self.hook_files.contains(lib_name) && !is_wsl_driver {
                debug!(
                    "[dlopen] Hooked '{}' -> '{}'.",
                    lib_path.display(),
                    self.self_path.to_string_lossy()
                );
                // The read guard is dropped at the end of this statement, so no lock is held
                // while the real `dlopen` runs. `dlopen` runs library initializers, which could
                // re-enter this function.
                let already_opened = self.hooked_libs.read().contains_key(lib_path);
                if !already_opened {
                    match dynlib::dlopen(lib_path, dynlib::DlOpenFlag::LAZY) {
                        Ok(handle) => {
                            // Two threads may race to open the same path. `or_insert` keeps the
                            // first handle; the other one is dropped (what dropping a
                            // `DlHandle` does is defined by `dynlib`).
                            self.hooked_libs
                                .write()
                                .entry(lib_path.to_owned())
                                .or_insert(handle);
                        }
                        Err(err) => {
                            // NOTE(review): if `dynlib::dlopen` consumes `dlerror()` internally,
                            // the caller's own `dlerror()` may return NULL here — confirm.
                            debug!(
                                "[dlopen] Failed to open original '{}': {:?}",
                                lib_path.display(),
                                err
                            );
                            return ptr::null_mut();
                        }
                    }
                }
                self.self_path.as_ptr()
            } else {
                trace!("[dlopen] Bypassing '{}'.", lib_path.to_string_lossy());
                filename.as_ptr()
            };

            // SAFETY: `filename` points either to the caller's string or to `self.self_path`.
            // Both are NUL-terminated and outlive this call.
            unsafe { dynlib::ffi::dlopen(filename, flag) }
        }

        /// Looks up `name` in the real (non-redirected) libraries.
        ///
        /// Retained handles are searched most-recently-inserted first. If none of them has the
        /// symbol, the lookup falls back to `dlsym` on `dynlib::NEXT_HANDLE` (presumably
        /// `RTLD_NEXT`). Lookup errors are treated as "not found".
        pub fn find_original_symbol<S>(&self, name: S) -> Option<dynlib::DlSymbol>
        where
            S: AsRef<OsStr>,
        {
            // The read guard is a temporary of this statement. It is released before the
            // fallback lookup below.
            let from_hooked = self
                .hooked_libs
                .read()
                .values()
                .rev()
                .find_map(|handle| handle.find_symbol(&name).ok().flatten());

            from_hooked.or_else(|| dynlib::dlsym(&dynlib::NEXT_HANDLE, name).ok().flatten())
        }
    }
}
/// Process-wide `dlopen` interceptor, built on first access.
///
/// Initialization failure is treated as fatal: the closure panics if
/// `interceptor::DlInterceptor::new` returns an error.
static DL_INTERCEPTOR: Lazy<interceptor::DlInterceptor> = Lazy::new(|| {
    // Library names whose `dlopen` calls get redirected. They are compared against the
    // interceptor's normalized names, so a versioned request like `libcuda.so.1` matches
    // `libcuda.so`.
    const HOOKED_LIBRARIES: &[&str] = &[
        "libcuda.so",
        "libcudart.so",
        "libnvidia-ml.so",
        "libnccl.so",
        "libcublas.so",
        "libcublasLt.so",
    ];

    match interceptor::DlInterceptor::new(HOOKED_LIBRARIES) {
        Ok(dl_interceptor) => {
            debug!(
                "Dynamic library interceptor initialized: {:#?}",
                dl_interceptor
            );
            dl_interceptor
        }
        // Same message shape as `Result::expect`: "<msg>: <err:?>".
        Err(err) => panic!("FATAL: Failed to initialize dynamic library interceptor.: {err:?}"),
    }
});
pub struct DlApiImpl;
impl DlApi for DlApiImpl {
fn dlopen(&self, filename: *const c_char, flag: c_int) -> *mut c_void {
DL_INTERCEPTOR.dlopen(filename, flag)
}
}
/// Resolves `name` in the real (non-redirected) libraries via the process-wide interceptor.
///
/// Thin wrapper over `DlInterceptor::find_original_symbol` on the global instance. The first
/// call initializes the interceptor, which panics if that initialization fails.
#[inline]
pub fn find_original_symbol<S: AsRef<OsStr>>(name: S) -> Option<dynlib::DlSymbol> {
    let global = &*DL_INTERCEPTOR;
    global.find_original_symbol(name)
}