* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::ffi::{c_int, c_longlong, c_void};
use smallvec::smallvec;
use tracing::{debug, info};
use cudax::cublas::*;
use xgpu_common::{
api_name::ApiFuncName,
ipc::message::{Argument, ArgumentFlag, Request},
};
use crate::{agent::Agent, fault_guard::virt, hook::cublas::CublasApi};
/// Returns the byte width of a scalar for the given cuBLAS compute type.
///
/// The `*GemmEx` hooks in this file use the result to decide how many bytes to copy
/// out of the host-side `alpha` and `beta` pointers.
///
/// * `CUBLAS_COMPUTE_16F*` compute types map to 2 bytes;
/// * `CUBLAS_COMPUTE_32F*` and `CUBLAS_COMPUTE_32I*` compute types map to 4 bytes;
/// * `CUBLAS_COMPUTE_64F*` compute types, and any value not listed above, map to 8 bytes.
///
/// NOTE(review): complex scale types (e.g. `CUDA_C_32F` data with `CUBLAS_COMPUTE_32F`)
/// are twice as wide as the value returned here. The data type is not visible to this
/// function, so confirm that complex GEMMs are not routed through it.
// Kept from the original; presumably silences lint noise from the bindgen-generated
// cuBLAS constants used as patterns below — TODO confirm it is still needed.
#[allow(non_upper_case_globals)]
fn get_compute_type_size(compute_type: u32) -> usize {
    const HALF_BYTES: usize = 2;
    const WORD_BYTES: usize = 4;
    const DOUBLE_BYTES: usize = 8;

    match compute_type {
        CUBLAS_COMPUTE_64F | CUBLAS_COMPUTE_64F_PEDANTIC => DOUBLE_BYTES,
        CUBLAS_COMPUTE_32I
        | CUBLAS_COMPUTE_32I_PEDANTIC
        | CUBLAS_COMPUTE_32F
        | CUBLAS_COMPUTE_32F_PEDANTIC
        | CUBLAS_COMPUTE_32F_FAST_16F
        | CUBLAS_COMPUTE_32F_FAST_16BF
        | CUBLAS_COMPUTE_32F_FAST_TF32 => WORD_BYTES,
        CUBLAS_COMPUTE_16F | CUBLAS_COMPUTE_16F_PEDANTIC => HALF_BYTES,
        // Unrecognized compute types fall back to the widest real scalar width.
        _ => DOUBLE_BYTES,
    }
}
pub struct CublasApiImpl;
impl CublasApi for CublasApiImpl {
fn cublasCreate_v2(&self, handle: *mut cublasHandle_t) -> cublasStatus_t {
info!("[Hooked] api_name: cublasCreate_v2");
debug!("before: handle ptr:{:p}, *handle:{:?}", handle, unsafe {
*handle
});
let mut handle_usize = unsafe { *handle };
let req = Request::with_args(
ApiFuncName::CublascreateV2 as u64,
smallvec![Argument::from_mut(
&mut handle_usize,
ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
)],
);
let request_id = req.request_id();
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!("after: handle ptr:{:p}, *handle:{:?}", handle, unsafe {
*handle
});
unsafe {
*handle = handle_usize as cublasHandle_t;
virt::handle_insert(*handle as *mut c_void, *handle as *mut c_void, 0)
.expect("handle_insert failed");
virt::req_id_vhandle_insert(request_id, *handle as *mut c_void)
.expect("req_id_vhandle_insert failed");
}
res
}
fn cublasSetWorkspace_v2(
&self,
handle: cublasHandle_t,
workspace: *mut c_void,
workspace_size_in_bytes: usize,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasSetWorkspace_v2");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let workspace_usize = virt::handle_map(workspace).expect("handle_map failed") as usize;
let req = Request::with_args(
ApiFuncName::CublassetworkspaceV2 as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(
&workspace_usize,
ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
),
Argument::from_ref(&workspace_size_in_bytes, ArgumentFlag::ARG_IN),
],
);
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!(
"workspace:{:p}, size:{}",
workspace, workspace_size_in_bytes
);
res
}
fn cublasSetStream_v2(
&self,
handle: cublasHandle_t,
stream_id: cudaStream_t,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasSetStream_v2");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let stream_id_usize = virt::handle_map(stream_id as *mut c_void).expect("handle_map failed")
as cudaStream_t as usize;
let req = Request::with_args(
ApiFuncName::CublassetstreamV2 as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(
&stream_id_usize,
ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT,
),
],
);
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!(" client-2 handle ptr:{:?}", handle);
debug!(" client-2 stream_id ptr:{:?}", stream_id);
res
}
fn cublasGetStream_v2(
&self,
handle: cublasHandle_t,
stream_id: *mut cudaStream_t,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasGetStream_v2");
debug!(" client handle ptr:{:?}", handle);
debug!(
" client1: stream_id ptr:{:p}, *ptr:{:?} ",
stream_id,
unsafe { *stream_id }
);
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let mut stream_id_usize = unsafe { *stream_id };
let req = Request::with_args(
ApiFuncName::CublasgetstreamV2 as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_mut(
&mut stream_id_usize,
ArgumentFlag::ARG_OUT | ArgumentFlag::ARG_VIRT,
),
],
);
let request_id = req.request_id();
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!(
" client2: stream_id ptr:{:p}, *ptr:{:?} ",
stream_id,
unsafe { *stream_id }
);
unsafe {
*stream_id = stream_id_usize as cudaStream_t;
virt::handle_insert(*stream_id as *mut c_void, *stream_id as *mut c_void, 0)
.expect("handle_insert failed");
virt::req_id_vhandle_insert(request_id, *stream_id as *mut c_void)
.expect("req_id_vhandle_insert failed");
}
res
}
fn cublasGetMathMode(&self, handle: cublasHandle_t, mode: *mut cublasMath_t) -> cublasStatus_t {
info!("[Hooked] api_name: cublasGetMathMode");
debug!(
"before: handle: {}, mode: {:p}, *mode: {:?}",
handle,
mode,
unsafe { *mode }
);
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let req = Request::with_args(
ApiFuncName::Cublasgetmathmode as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
unsafe { Argument::from_mut_ptr(mode, ArgumentFlag::ARG_OUT) },
],
);
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!(
"after: handle: {}, mode: {:p}, *mode: {:?}",
handle,
mode,
unsafe { *mode }
);
res
}
fn cublasSetMathMode(&self, handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t {
info!("[Hooked] api_name: cublasSetMathMode");
debug!("before: handle: {}, mode:{}", handle, mode);
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let req = Request::with_args(
ApiFuncName::Cublassetmathmode as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&mode, ArgumentFlag::ARG_IN),
],
);
let res = Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed");
debug!("after: handle: {}, mode:{}", handle, mode);
res
}
fn cublasSgemm_v2(
&self,
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
b: *const f32,
ldb: c_int,
beta: *const f32,
c: *mut f32,
ldc: c_int,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasSgemm_v2");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let a_usize =
virt::handle_map(a as *mut c_void).expect("handle_map failed") as *const f32 as usize;
let b_usize =
virt::handle_map(b as *mut c_void).expect("handle_map failed") as *const f32 as usize;
let c_usize =
virt::handle_map(c as *mut c_void).expect("handle_map failed") as *mut f32 as usize;
let req = Request::with_args(
ApiFuncName::CublassgemmV2 as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
Argument::from_ref(&m, ArgumentFlag::ARG_IN),
Argument::from_ref(&n, ArgumentFlag::ARG_IN),
Argument::from_ref(&k, ArgumentFlag::ARG_IN),
unsafe { Argument::from_ptr(alpha, ArgumentFlag::ARG_IN) },
Argument::from_ref(&a_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
Argument::from_ref(&b_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
unsafe { Argument::from_ptr(beta, ArgumentFlag::ARG_IN) },
Argument::from_ref(&c_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
],
);
Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed")
}
fn cublasGemmEx(
&self,
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
atype: cudaDataType,
lda: c_int,
b: *const c_void,
btype: cudaDataType,
ldb: c_int,
beta: *const c_void,
c: *mut c_void,
ctype: cudaDataType,
ldc: c_int,
compute_type: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasGemmEx");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let a_usize = virt::handle_map(a as *mut c_void).expect("handle_map failed")
as *const c_void as usize;
let b_usize = virt::handle_map(b as *mut c_void).expect("handle_map failed")
as *const c_void as usize;
let c_usize = virt::handle_map(c).expect("handle_map failed") as usize;
let compute_type_size = get_compute_type_size(compute_type);
let alpha_slice =
unsafe { std::slice::from_raw_parts(alpha.cast::<u8>(), compute_type_size) };
let beta_slice =
unsafe { std::slice::from_raw_parts(beta.cast::<u8>(), compute_type_size) };
let req = Request::with_args(
ApiFuncName::Cublasgemmex as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
Argument::from_ref(&m, ArgumentFlag::ARG_IN),
Argument::from_ref(&n, ArgumentFlag::ARG_IN),
Argument::from_ref(&k, ArgumentFlag::ARG_IN),
Argument::from_slice(alpha_slice, ArgumentFlag::ARG_IN),
Argument::from_ref(&a_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&atype, ArgumentFlag::ARG_IN),
Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
Argument::from_ref(&b_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&btype, ArgumentFlag::ARG_IN),
Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
Argument::from_slice(beta_slice, ArgumentFlag::ARG_IN),
Argument::from_ref(&c_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ctype, ArgumentFlag::ARG_IN),
Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
Argument::from_ref(&compute_type, ArgumentFlag::ARG_IN),
Argument::from_ref(&algo, ArgumentFlag::ARG_IN),
],
);
Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed")
}
fn cublasSgemmStridedBatched(
&self,
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
stride_a: c_longlong,
b: *const f32,
ldb: c_int,
stride_b: c_longlong,
beta: *const f32,
c: *mut f32,
ldc: c_int,
stride_c: c_longlong,
batch_count: c_int,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasSgemmStridedBatched");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let a_usize =
virt::handle_map(a as *mut c_void).expect("handle_map failed") as *const f32 as usize;
let b_usize =
virt::handle_map(b as *mut c_void).expect("handle_map failed") as *const f32 as usize;
let c_usize =
virt::handle_map(c as *mut c_void).expect("handle_map failed") as *mut f32 as usize;
let req = Request::with_args(
ApiFuncName::Cublassgemmstridedbatched as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
Argument::from_ref(&m, ArgumentFlag::ARG_IN),
Argument::from_ref(&n, ArgumentFlag::ARG_IN),
Argument::from_ref(&k, ArgumentFlag::ARG_IN),
unsafe { Argument::from_ptr(alpha, ArgumentFlag::ARG_IN) },
Argument::from_ref(&a_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_a, ArgumentFlag::ARG_IN),
Argument::from_ref(&b_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_b, ArgumentFlag::ARG_IN),
unsafe { Argument::from_ptr(beta, ArgumentFlag::ARG_IN) },
Argument::from_ref(&c_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_c, ArgumentFlag::ARG_IN),
Argument::from_ref(&batch_count, ArgumentFlag::ARG_IN),
],
);
Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed")
}
fn cublasGemmStridedBatchedEx(
&self,
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
atype: cudaDataType,
lda: c_int,
stride_a: c_longlong,
b: *const c_void,
btype: cudaDataType,
ldb: c_int,
stride_b: c_longlong,
beta: *const c_void,
c: *mut c_void,
ctype: cudaDataType,
ldc: c_int,
stride_c: c_longlong,
batch_count: c_int,
compute_type: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
info!("[Hooked] api_name: cublasGemmStridedBatchedEx");
let handle_usize = virt::handle_map(handle as *mut c_void).expect("handle_map failed")
as cublasHandle_t as usize;
let a_usize = virt::handle_map(a as *mut c_void).expect("handle_map failed")
as *const c_void as usize;
let b_usize = virt::handle_map(b as *mut c_void).expect("handle_map failed")
as *const c_void as usize;
let c_usize = virt::handle_map(c).expect("handle_map failed") as usize;
let compute_type_size = get_compute_type_size(compute_type);
let alpha_slice =
unsafe { std::slice::from_raw_parts(alpha.cast::<u8>(), compute_type_size) };
let beta_slice =
unsafe { std::slice::from_raw_parts(beta.cast::<u8>(), compute_type_size) };
let req = Request::with_args(
ApiFuncName::Cublasgemmstridedbatchedex as u64,
smallvec![
Argument::from_ref(&handle_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&transa, ArgumentFlag::ARG_IN),
Argument::from_ref(&transb, ArgumentFlag::ARG_IN),
Argument::from_ref(&m, ArgumentFlag::ARG_IN),
Argument::from_ref(&n, ArgumentFlag::ARG_IN),
Argument::from_ref(&k, ArgumentFlag::ARG_IN),
Argument::from_slice(alpha_slice, ArgumentFlag::ARG_IN),
Argument::from_ref(&a_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&atype, ArgumentFlag::ARG_IN),
Argument::from_ref(&lda, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_a, ArgumentFlag::ARG_IN),
Argument::from_ref(&b_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&btype, ArgumentFlag::ARG_IN),
Argument::from_ref(&ldb, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_b, ArgumentFlag::ARG_IN),
Argument::from_slice(beta_slice, ArgumentFlag::ARG_IN),
Argument::from_ref(&c_usize, ArgumentFlag::ARG_IN | ArgumentFlag::ARG_VIRT),
Argument::from_ref(&ctype, ArgumentFlag::ARG_IN),
Argument::from_ref(&ldc, ArgumentFlag::ARG_IN),
Argument::from_ref(&stride_c, ArgumentFlag::ARG_IN),
Argument::from_ref(&batch_count, ArgumentFlag::ARG_IN),
Argument::from_ref(&compute_type, ArgumentFlag::ARG_IN),
Argument::from_ref(&algo, ArgumentFlag::ARG_IN),
],
);
Agent::get_instance()
.invoke_api::<cublasStatus_t>(req)
.expect("call invoke_api failed")
}
}