* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is part of the MindStudio project.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "csrc/callback/callback_manager.h"
#include <memory>
#include "csrc/common/inject/mstx_inject.h"
#include "csrc/activity/activity_manager.h"
#include "csrc/activity/ascend/dev_task_manager.h"
#include "csrc/activity/ascend/channel/channel_pool_manager.h"
#include "csrc/common/plog_manager.h"
#include "csrc/common/context_manager.h"
#include "csrc/common/utils.h"
namespace Mspti {
namespace Callback {
namespace {
inline bool IsValidCBDomain(msptiCallbackDomain domain)
{
return domain > MSPTI_CB_DOMAIN_INVALID && domain < MSPTI_CB_DOMAIN_SIZE;
}
inline bool IsValidCBId(msptiCallbackId cbid)
{
return cbid < sizeof(CallbackManager::BitMap) * 8;
}
inline msptiResult HasLdPreload()
{
static const std::string ld = Mspti::Common::Utils::GetEnv("LD_PRELOAD");
if (ld.find("libmspti.so") == std::string::npos) {
MSPTI_LOGE("Enable callbackDomain requires libmspti.so in LD_PRELOAD.");
return MSPTI_ERROR_WITHOUT_LD_PRELOAD;
}
return MSPTI_SUCCESS;
}
}
std::unordered_map<msptiCallbackDomain, std::unordered_set<msptiCallbackId>> CallbackManager::domain_cbid_map_ = {
{MSPTI_CB_DOMAIN_RUNTIME, {
MSPTI_CBID_RUNTIME_DEVICE_SET, MSPTI_CBID_RUNTIME_DEVICE_RESET,
MSPTI_CBID_RUNTIME_DEVICE_SET_EX, MSPTI_CBID_RUNTIME_CONTEXT_CREATED_EX,
MSPTI_CBID_RUNTIME_CONTEXT_CREATED, MSPTI_CBID_RUNTIME_CONTEXT_DESTROY,
MSPTI_CBID_RUNTIME_STREAM_CREATED, MSPTI_CBID_RUNTIME_STREAM_DESTROY,
MSPTI_CBID_RUNTIME_STREAM_SYNCHRONIZED, MSPTI_CBID_RUNTIME_LAUNCH,
MSPTI_CBID_RUNTIME_CPU_LAUNCH, MSPTI_CBID_RUNTIME_AICPU_LAUNCH, MSPTI_CBID_RUNTIME_AIV_LAUNCH,
MSPTI_CBID_RUNTIME_FFTS_LAUNCH, MSPTI_CBID_RUNTIME_MALLOC, MSPTI_CBID_RUNTIME_FREE,
MSPTI_CBID_RUNTIME_MALLOC_HOST, MSPTI_CBID_RUNTIME_FREE_HOST, MSPTI_CBID_RUNTIME_MALLOC_CACHED,
MSPTI_CBID_RUNTIME_FLUSH_CACHE, MSPTI_CBID_RUNTIME_INVALID_CACHE, MSPTI_CBID_RUNTIME_MEMCPY,
MSPTI_CBID_RUNTIME_MEMCPY_HOST, MSPTI_CBID_RUNTIME_MEMCPY_ASYNC,
MSPTI_CBID_RUNTIME_MEM_CPY2D, MSPTI_CBID_RUNTIME_MEM_CPY2D_ASYNC,
MSPTI_CBID_RUNTIME_MEM_SET, MSPTI_CBID_RUNTIME_MEM_SET_ASYNC, MSPTI_CBID_RUNTIME_MEM_GET_INFO,
MSPTI_CBID_RUNTIME_RESERVE_MEM_ADDRESS, MSPTI_CBID_RUNTIME_RELEASE_MEM_ADDRESS,
MSPTI_CBID_RUNTIME_MALLOC_PHYSICAL, MSPTI_CBID_RUNTIME_FREE_PHYSICAL,
MSPTI_CBID_RUNTIME_MEM_EXPORT_TO_SHAREABLE_HANDLE, MSPTI_CBID_RUNTIME_MEM_IMPORT_FROM_SHAREABLE_HANDLE,
MSPTI_CBID_RUNTIME_MEM_SET_PID_TO_SHAREABLE_HANDLE
}},
{MSPTI_CB_DOMAIN_HCCL, {
MSPTI_CBID_HCCL_ALLREDUCE, MSPTI_CBID_HCCL_BROADCAST, MSPTI_CBID_HCCL_ALLGATHER,
MSPTI_CBID_HCCL_REDUCE_SCATTER, MSPTI_CBID_HCCL_REDUCE, MSPTI_CBID_HCCL_ALL_TO_ALL,
MSPTI_CBID_HCCL_ALL_TO_ALLV, MSPTI_CBID_HCCL_BARRIER,
MSPTI_CBID_HCCL_SCATTER, MSPTI_CBID_HCCL_SEND, MSPTI_CBID_HCCL_RECV, MSPTI_CBID_HCCL_SENDRECV
}}
};
CallbackManager *CallbackManager::GetInstance()
{
static CallbackManager instance;
return &instance;
}
msptiResult CallbackManager::Init(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void* userdata)
{
if (subscriber == nullptr) {
MSPTI_LOGE("subscriber cannot be nullptr.");
return MSPTI_ERROR_INVALID_PARAMETER;
}
if (init_.load()) {
MSPTI_LOGE("subscriber cannot be register repeat.");
return MSPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED;
}
Mspti::Common::MsptiMakeUniquePtr(subscriber_ptr_);
if (!subscriber_ptr_) {
MSPTI_LOGE("Failed to init subscriber.");
return MSPTI_ERROR_INNER;
}
for (auto& bitmap : cbid_map_) {
bitmap.store(0, std::memory_order_relaxed);
}
subscriber_ptr_->handle = callback;
subscriber_ptr_->userdata = userdata;
*subscriber = subscriber_ptr_.get();
init_.store(true);
Mspti::Ascend::DevTaskManager::GetInstance()->RegisterReportCallback();
Mspti::Common::ContextManager::GetInstance()->StartSyncTime();
MsptiMstxApi::MsptiEnableMstxFunc();
Mspti::Ascend::Channel::ChannelPoolManager::GetInstance()->Init();
MSPTI_LOGI("CallbackManager Init success.");
return MSPTI_SUCCESS;
}
msptiResult CallbackManager::UnInit(msptiSubscriberHandle subscriber)
{
if (!init_.load()) {
return MSPTI_SUCCESS;
}
if (subscriber_ptr_.get() != subscriber) {
MSPTI_LOGE("[ERROR] subscriber was not subscribe.");
return MSPTI_ERROR_INVALID_PARAMETER;
}
subscriber_ptr_.reset(nullptr);
for (auto& bitmap : cbid_map_) {
bitmap.store(0, std::memory_order_relaxed);
}
init_.store(false);
MsptiMstxApi::MsptiDisableMstxFunc();
Mspti::Common::ContextManager::GetInstance()->StopSyncTime();
Mspti::Ascend::DevTaskManager::GetInstance()->UnRegisterReportCallback();
Mspti::Ascend::Channel::ChannelPoolManager::GetInstance()->UnInit();
MSPTI_LOGI("CallbackManager UnInit success.");
return MSPTI_SUCCESS;
}
msptiResult CallbackManager::Register(msptiCallbackDomain domain, msptiCallbackId cbid)
{
if (!IsValidCBDomain(domain) || !IsValidCBId(cbid)) {
return MSPTI_ERROR_INVALID_PARAMETER;
}
auto idx = static_cast<size_t>(domain);
uint64_t mask = 1ULL << static_cast<int>(cbid);
cbid_map_[idx].fetch_or(mask, std::memory_order_relaxed);
MSPTI_LOGI("CallbackManager Register domain: %d, cbid: %d.", domain, cbid);
return MSPTI_SUCCESS;
}
msptiResult CallbackManager::UnRegister(msptiCallbackDomain domain, msptiCallbackId cbid)
{
if (!IsValidCBDomain(domain) || !IsValidCBId(cbid)) {
return MSPTI_ERROR_INVALID_PARAMETER;
}
auto idx = static_cast<size_t>(domain);
uint64_t mask = ~(1ULL << static_cast<int>(cbid));
cbid_map_[idx].fetch_and(mask, std::memory_order_relaxed);
MSPTI_LOGI("CallbackManager UnRegister domain: %d, cbid: %d.", domain, cbid);
return MSPTI_SUCCESS;
}
msptiResult CallbackManager::EnableCallback(uint32_t enable,
msptiSubscriberHandle subscriber, msptiCallbackDomain domain, msptiCallbackId cbid)
{
if (HasLdPreload() != MSPTI_SUCCESS) {
return MSPTI_ERROR_WITHOUT_LD_PRELOAD;
}
if (!init_.load()) {
MSPTI_LOGW("CallbackManager was not init.");
return MSPTI_SUCCESS;
}
if (subscriber != subscriber_ptr_.get()) {
MSPTI_LOGE("subscriber was not subscribe.");
return MSPTI_ERROR_INVALID_PARAMETER;
}
if (!IsValidCBDomain(domain)) {
MSPTI_LOGE("domain: %d was invalid.", domain);
return MSPTI_ERROR_INVALID_PARAMETER;
}
return (enable != 0) ? Register(domain, cbid) : UnRegister(domain, cbid);
}
msptiResult CallbackManager::EnableDomain(uint32_t enable,
msptiSubscriberHandle subscriber, msptiCallbackDomain domain)
{
if (HasLdPreload() != MSPTI_SUCCESS) {
return MSPTI_ERROR_WITHOUT_LD_PRELOAD;
}
if (!init_.load()) {
MSPTI_LOGW("CallbackManager was not init.");
return MSPTI_SUCCESS;
}
if (subscriber != subscriber_ptr_.get()) {
MSPTI_LOGE("subscriber was not subscribe.");
return MSPTI_ERROR_INVALID_PARAMETER;
}
if (!IsValidCBDomain(domain)) {
MSPTI_LOGE("domain: %d was invalid.", domain);
return MSPTI_ERROR_INVALID_PARAMETER;
}
auto cbid_set = domain_cbid_map_.find(domain);
if (cbid_set == domain_cbid_map_.end()) {
return MSPTI_SUCCESS;
}
msptiResult ret = MSPTI_SUCCESS;
for (const auto& cbid : cbid_set->second) {
auto reg_ret = (enable != 0) ? Register(domain, cbid) : UnRegister(domain, cbid);
ret = (reg_ret != MSPTI_SUCCESS) ? reg_ret : ret;
}
return ret;
}
void CallbackManager::ExecuteCallback(msptiCallbackDomain domain,
msptiCallbackId cbid, msptiApiCallbackSite site, const char* funcName)
{
if (!init_.load()) {
return;
}
if (!IsCallbackIdEnable(domain, cbid)) {
return;
}
if (subscriber_ptr_->handle) {
MSPTI_LOGD("CallbackManager execute Callbackfunc, funcName is %s", funcName);
msptiCallbackData callbackData;
callbackData.callbackSite = site;
callbackData.functionName = funcName;
subscriber_ptr_->handle(subscriber_ptr_->userdata, domain, cbid, &callbackData);
}
}
bool CallbackManager::IsCallbackIdEnable(msptiCallbackDomain domain, msptiCallbackId cbid)
{
if (!IsValidCBDomain(domain) || !IsValidCBId(cbid)) {
return false;
}
auto idx = static_cast<size_t>(domain);
uint64_t bits = cbid_map_[idx].load(std::memory_order_relaxed);
return (bits >> static_cast<int>(cbid)) & 1;
}
}
}
msptiResult msptiSubscribe(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void *userdata)
{
return Mspti::Callback::CallbackManager::GetInstance()->Init(subscriber, callback, userdata);
}
msptiResult msptiUnsubscribe(msptiSubscriberHandle subscriber)
{
if (Mspti::Activity::ActivityManager::GetInstance()->ResetAllDevice() != MSPTI_SUCCESS) {
MSPTI_LOGE("Reset all device failed.");
}
return Mspti::Callback::CallbackManager::GetInstance()->UnInit(subscriber);
}
msptiResult msptiEnableCallback(
uint32_t enable, msptiSubscriberHandle subscriber, msptiCallbackDomain domain, msptiCallbackId cbid)
{
return Mspti::Callback::CallbackManager::GetInstance()->EnableCallback(enable, subscriber, domain, cbid);
}
msptiResult msptiEnableDomain(uint32_t enable, msptiSubscriberHandle subscriber, msptiCallbackDomain domain)
{
return Mspti::Callback::CallbackManager::GetInstance()->EnableDomain(enable, subscriber, domain);
}