* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is part of the MindStudio project.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "csrc/common/inject/acl_inject.h"
#include <functional>
#include "csrc/activity/activity_manager.h"
#include "csrc/callback/callback_manager.h"
namespace {
const std::string SO_NAME = "libascendcl";
const std::string SO_FILE_NAME = SO_NAME + ".so";
enum AclRtFuncIndex {
FUNC_ACL_RT_SET_DEVICE,
FUNC_ACL_RT_RESET_DEVICE,
FUNC_ACL_RT_CREATE_CTX,
FUNC_ACL_RT_DESTROY_CTX,
FUNC_ACL_RT_CREATE_STREAM,
FUNC_ACL_RT_DESTROY_STREAM,
FUNC_ACL_RT_SYNCHRONIZE_STREAM,
FUNC_ACL_RT_LAUNCH_KERNEL,
FUNC_ACL_RT_LAUNCH_KERNEL_V2,
FUNC_ACL_RT_LAUNCH_KERNEL_WITH_CFG,
FUNC_ACL_RT_LAUNCH_KERNEL_WITH_HOST_ARGS,
FUNC_ACL_RT_GET_DEVICE,
FUNC_ACL_RT_STREAM_GET_ID,
FUNC_ACL_RT_PROF_TRACE,
FUNC_ACL_RT_BINARY_GET_FUNC_BY_ENTRY,
FUNC_ACL_RT_GET_LOGIC_DEV_ID_BY_USER_DEV_ID,
FUNC_ACL_RT_COUNT
};
pthread_once_t g_once;
void* g_aclrtFuncArray[FUNC_ACL_RT_COUNT];
void LoadAclFunction()
{
g_aclrtFuncArray[FUNC_ACL_RT_SET_DEVICE] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtSetDevice");
g_aclrtFuncArray[FUNC_ACL_RT_RESET_DEVICE] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtResetDevice");
g_aclrtFuncArray[FUNC_ACL_RT_CREATE_CTX] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtCreateContext");
g_aclrtFuncArray[FUNC_ACL_RT_DESTROY_CTX] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtDestroyContext");
g_aclrtFuncArray[FUNC_ACL_RT_CREATE_STREAM] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtCreateStream");
g_aclrtFuncArray[FUNC_ACL_RT_DESTROY_STREAM] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtDestroyStream");
g_aclrtFuncArray[FUNC_ACL_RT_SYNCHRONIZE_STREAM] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtSynchronizeStream");
g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtLaunchKernel");
g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_V2] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtLaunchKernelV2");
g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_WITH_CFG] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtLaunchKernelWithConfig");
g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_WITH_HOST_ARGS] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtLaunchKernelWithHostArgs");
g_aclrtFuncArray[FUNC_ACL_RT_GET_DEVICE] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtGetDevice");
g_aclrtFuncArray[FUNC_ACL_RT_STREAM_GET_ID] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtStreamGetId");
g_aclrtFuncArray[FUNC_ACL_RT_PROF_TRACE] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtProfTrace");
g_aclrtFuncArray[FUNC_ACL_RT_BINARY_GET_FUNC_BY_ENTRY] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtBinaryGetFunctionByEntry");
g_aclrtFuncArray[FUNC_ACL_RT_GET_LOGIC_DEV_ID_BY_USER_DEV_ID] =
Mspti::Common::RegisterFunction(SO_NAME, "aclrtGetLogicDevIdByUserDevId");
}
}
AclError aclrtSetDevice(int32_t deviceId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_SET_DEVICE];
using aclrtSetDeviceFunc = std::function<decltype(aclrtSetDevice)>;
aclrtSetDeviceFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtSetDevice)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_DEVICE_SET, __FUNCTION__);
return func(deviceId);
}
AclError aclrtResetDevice(int32_t deviceId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_RESET_DEVICE];
using aclrtResetDeviceFunc = std::function<decltype(aclrtResetDevice)>;
aclrtResetDeviceFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtResetDevice)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_DEVICE_RESET, __FUNCTION__);
return func(deviceId);
}
AclError aclrtCreateContext(void **context, int32_t deviceId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_CREATE_CTX];
using aclrtCreateContextFunc = std::function<decltype(aclrtCreateContext)>;
aclrtCreateContextFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtCreateContext)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_CONTEXT_CREATED, __FUNCTION__);
return func(context, deviceId);
}
AclError aclrtDestroyContext(void *context)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_DESTROY_CTX];
using aclrtDestroyContextFunc = std::function<decltype(aclrtDestroyContext)>;
aclrtDestroyContextFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtDestroyContext)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_CONTEXT_DESTROY, __FUNCTION__);
return func(context);
}
AclError aclrtCreateStream(AclrtStream *stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_CREATE_STREAM];
using aclrtCreateStreamFunc = std::function<decltype(aclrtCreateStream)>;
aclrtCreateStreamFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtCreateStream)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_STREAM_CREATED, __FUNCTION__);
return func(stream);
}
AclError aclrtDestroyStream(AclrtStream stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_DESTROY_STREAM];
using aclrtDestroyStreamFunc = std::function<decltype(aclrtDestroyStream)>;
aclrtDestroyStreamFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtDestroyStream)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_STREAM_DESTROY, __FUNCTION__);
return func(stream);
}
AclError aclrtSynchronizeStream(AclrtStream stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_SYNCHRONIZE_STREAM];
using aclrtSynchronizeStreamFunc = std::function<decltype(aclrtSynchronizeStream)>;
aclrtSynchronizeStreamFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtSynchronizeStream)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_STREAM_SYNCHRONIZED,
__FUNCTION__);
return func(stream);
}
AclError aclrtLaunchKernel(AclrtFuncHandle funcHandle, uint32_t blockDim,
const void *argsData, size_t argsSize, AclrtStream stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL];
using aclrtLaunchKernelFunc = std::function<decltype(aclrtLaunchKernel)>;
aclrtLaunchKernelFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtLaunchKernel)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Common::ContextManager::GetInstance()->UpdateAndReportCorrelationId();
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_LAUNCH, __FUNCTION__);
return func(funcHandle, blockDim, argsData, argsSize, stream);
}
AclError aclrtLaunchKernelV2(AclrtFuncHandle funcHandle, uint32_t blockDim, const void *argsData,
size_t argsSize, AclrtLaunchKernelCfg *cfg, AclrtStream stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_V2];
using aclrtLaunchKernelV2Func = std::function<decltype(aclrtLaunchKernelV2)>;
aclrtLaunchKernelV2Func func = Mspti::Common::ReinterpretConvert<decltype(&aclrtLaunchKernelV2)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Common::ContextManager::GetInstance()->UpdateAndReportCorrelationId();
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_LAUNCH, __FUNCTION__);
return func(funcHandle, blockDim, argsData, argsSize, cfg, stream);
}
AclError aclrtLaunchKernelWithConfig(AclrtFuncHandle funcHandle, uint32_t blockDim, AclrtStream stream,
AclrtLaunchKernelCfg *cfg, AclrtArgsHandle argsHandle, void *reserve)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_WITH_CFG];
using aclrtLaunchKernelWithConfigFunc = std::function<decltype(aclrtLaunchKernelWithConfig)>;
aclrtLaunchKernelWithConfigFunc func =
Mspti::Common::ReinterpretConvert<decltype(&aclrtLaunchKernelWithConfig)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Common::ContextManager::GetInstance()->UpdateAndReportCorrelationId();
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_LAUNCH, __FUNCTION__);
return func(funcHandle, blockDim, stream, cfg, argsHandle, reserve);
}
AclError aclrtLaunchKernelWithHostArgs(AclrtFuncHandle funcHandle, uint32_t blockDim, AclrtStream stream,
AclrtLaunchKernelCfg *cfg, void *hostArgs, size_t argsSize,
AclrtPlaceHolderInfo *placeHolderArray, size_t placeHolderNum)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_LAUNCH_KERNEL_WITH_HOST_ARGS];
using aclrtLaunchKernelWithHostArgsFunc = std::function<decltype(aclrtLaunchKernelWithHostArgs)>;
aclrtLaunchKernelWithHostArgsFunc func =
Mspti::Common::ReinterpretConvert<decltype(&aclrtLaunchKernelWithHostArgs)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
Mspti::Common::ContextManager::GetInstance()->UpdateAndReportCorrelationId();
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_RUNTIME, MSPTI_CBID_RUNTIME_LAUNCH, __FUNCTION__);
return func(funcHandle, blockDim, stream, cfg, hostArgs, argsSize, placeHolderArray, placeHolderNum);
}
AclError aclrtGetDevice(int32_t *deviceId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_GET_DEVICE];
using aclrtGetDeviceFunc = std::function<decltype(aclrtGetDevice)>;
aclrtGetDeviceFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtGetDevice)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
return func(deviceId);
}
AclError aclrtStreamGetId(AclrtStream stream, int32_t *streamId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_STREAM_GET_ID];
using aclrtStreamGetIdFunc = std::function<decltype(aclrtStreamGetId)>;
aclrtStreamGetIdFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtStreamGetId)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
return func(stream, streamId);
}
AclError aclrtProfTrace(void *userdata, int32_t length, AclrtStream stream)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_PROF_TRACE];
using aclrtProfTraceFunc = std::function<decltype(aclrtProfTrace)>;
aclrtProfTraceFunc func = Mspti::Common::ReinterpretConvert<decltype(&aclrtProfTrace)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
return func(userdata, length, stream);
}
AclError aclrtBinaryGetFunctionByEntry(AclrtBinHandle binHandle, uint64_t funcEntry, AclrtFuncHandle *funcHandle)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_BINARY_GET_FUNC_BY_ENTRY];
using aclrtBinaryGetFunctionByEntryFunc = std::function<decltype(aclrtBinaryGetFunctionByEntry)>;
aclrtBinaryGetFunctionByEntryFunc func =
Mspti::Common::ReinterpretConvert<decltype(&aclrtBinaryGetFunctionByEntry)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
return func(binHandle, funcEntry, funcHandle);
}
AclError aclrtGetLogicDevIdByUserDevId(const int32_t userDevid, int32_t *const logicDevId)
{
pthread_once(&g_once, LoadAclFunction);
void* voidFunc = g_aclrtFuncArray[FUNC_ACL_RT_GET_LOGIC_DEV_ID_BY_USER_DEV_ID];
using aclrtGetLogicDevIdByUserDevIdFunc = std::function<decltype(aclrtGetLogicDevIdByUserDevId)>;
aclrtGetLogicDevIdByUserDevIdFunc func =
Mspti::Common::ReinterpretConvert<decltype(&aclrtGetLogicDevIdByUserDevId)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction(SO_NAME, __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, SO_FILE_NAME);
return func(userDevid, logicDevId);
}