* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "mstx_manager.h"
#include <cstring>
#include "securec.h"
#include "call_stack.h"
#include "event_report.h"
#include "record_info.h"
#include "log.h"
#include "bit_field.h"
#include "aten_manager.h"
#include "memory_pool_trace/memory_pool_trace_manager.h"
#include "memory_pool_trace/atb_memory_pool_trace.h"
#include "memory_pool_trace/mindspore_memory_pool_trace.h"
#include "memory_pool_trace/pta_caching_pool_trace.h"
#include "memory_pool_trace/pta_workspace_pool_trace.h"
namespace MemScope {
/**
 * @brief Report an mstx "mark A" event.
 *
 * A message that starts with ATEN_MSG is not reported as a mark. Instead, the text after
 * the prefix is handed to AtenManager::ProcessMsg. Every other message is reported through
 * EventReport::Instance(type) as a MarkType::MARK_A event tagged with onlyMarkId_.
 *
 * @param msg      NUL-terminated mark message. May be nullptr. In that case an empty message
 *                 is reported, because constructing std::string from nullptr is undefined behavior.
 * @param streamId Stream identifier forwarded with the event.
 * @param type     Selects which EventReport instance reports the mark.
 */
void MstxManager::ReportMarkA(const char* msg, int32_t streamId, MemScopeCommType type)
{
    if (msg != nullptr) {
        const size_t atenPrefixLen = strlen(ATEN_MSG);
        if (strncmp(msg, ATEN_MSG, atenPrefixLen) == 0) {
            // ATen messages are routed to AtenManager with the prefix stripped off.
            AtenManager::GetInstance().ProcessMsg(msg + atenPrefixLen, streamId);
            return;
        }
    }
    // std::string(nullptr) is undefined behavior, so a null message becomes an empty one.
    std::string markMsg = (msg != nullptr) ? std::string(msg) : std::string();
    if (!EventReport::Instance(type).ReportMark(MarkType::MARK_A, markMsg, streamId, onlyMarkId_)) {
        LOG_ERROR("Report Mark FAILED");
    }
}
/**
 * @brief Open a new mstx range and report its start.
 *
 * A fresh identifier is taken from GetRangeId(). It is reported through the SHARED_MEMORY
 * EventReport instance as a MarkType::RANGE_START_A event.
 *
 * @param msg      NUL-terminated range message. May be nullptr. In that case an empty message
 *                 is reported, because constructing std::string from nullptr is undefined behavior.
 * @param streamId Stream identifier forwarded with the event.
 * @return The range identifier. It is returned even if reporting failed, which is only logged.
 */
uint64_t MstxManager::ReportRangeStart(const char* msg, int32_t streamId)
{
    uint64_t rangeId = GetRangeId();
    // Guard against a null message: std::string(nullptr) is undefined behavior.
    std::string markMsg = (msg != nullptr) ? std::string(msg) : std::string();
    if (!EventReport::Instance(MemScopeCommType::SHARED_MEMORY).ReportMark(
        MarkType::RANGE_START_A, markMsg, streamId, rangeId)) {
        LOG_ERROR("Report Mark FAILED");
    }
    return rangeId;
}
/**
 * @brief Report the end of a range.
 *
 * Sends a MarkType::RANGE_END event through the SHARED_MEMORY EventReport instance.
 * The event has stream id -1 and the message "Range end from id <id>".
 * A reporting failure is logged and otherwise ignored.
 *
 * @param id Identifier of the range being closed.
 */
void MstxManager::ReportRangeEnd(uint64_t id)
{
    std::string endMsg("Range end from id ");
    endMsg += std::to_string(id);

    const bool reported =
        EventReport::Instance(MemScopeCommType::SHARED_MEMORY).ReportMark(MarkType::RANGE_END, endMsg, -1, id);
    if (!reported) {
        LOG_ERROR("Report Mark FAILED");
    }
}
/**
 * @brief Hand out the next range identifier.
 * @return The current value of rangeId_. rangeId_ is then advanced by one.
 */
uint64_t MstxManager::GetRangeId()
{
    const uint64_t issuedId = rangeId_++;
    return issuedId;
}
/**
 * @brief Create an mstx domain for one of the supported memory-pool tracers.
 *
 * Recognised domain names, and the tracer registered for each:
 *   - "atb"                   -> ATBMemoryPoolTrace. Domain created only if registration succeeds.
 *   - "mindsporeMemPool"      -> MindsporeMemoryPoolTrace. Domain created only if registration succeeds.
 *   - "ptaCaching", "msleaks" -> PTACachingPoolTrace. Domain always created under the name "ptaCaching".
 *   - "ptaWorkspace"          -> PTAWorkspacePoolTrace. Domain always created.
 *
 * @param domainName NUL-terminated domain name. nullptr is rejected.
 * @return The created domain handle, or nullptr in any of these cases:
 *         a null name, an unrecognised name, or a failed "atb"/"mindsporeMemPool" registration.
 */
mstxDomainHandle_t MstxManager::ReportDomainCreateA(char const *domainName)
{
    if (domainName == nullptr) {
        // std::string(nullptr) is undefined behavior, and no domain could match anyway.
        return nullptr;
    }
    // Build the std::string once rather than once per comparison.
    const std::string name(domainName);

    if (name == "atb") {
        if (MemoryPoolTraceManager::GetInstance().RegisterMemoryPoolTracer("atb",
            &ATBMemoryPoolTrace::GetInstance())) {
            return MemoryPoolTraceManager::GetInstance().CreateDomain(domainName);
        }
        return nullptr;
    }
    if (name == "mindsporeMemPool") {
        if (MemoryPoolTraceManager::GetInstance().RegisterMemoryPoolTracer("mindsporeMemPool",
            &MindsporeMemoryPoolTrace::GetInstance())) {
            return MemoryPoolTraceManager::GetInstance().CreateDomain(domainName);
        }
        return nullptr;
    }
    if (name == "ptaCaching" || name == "msleaks") {
        // NOTE(review): the registration result is deliberately not checked here. Presumably this is
        // because both names share one tracer, and a repeated registration may report failure. Confirm.
        MemoryPoolTraceManager::GetInstance().RegisterMemoryPoolTracer("ptaCaching", &PTACachingPoolTrace::GetInstance());
        return MemoryPoolTraceManager::GetInstance().CreateDomain("ptaCaching");
    }
    if (name == "ptaWorkspace") {
        // NOTE(review): the registration result is not checked here either. Confirm this is intended.
        MemoryPoolTraceManager::GetInstance().RegisterMemoryPoolTracer("ptaWorkspace",
            &PTAWorkspacePoolTrace::GetInstance());
        return MemoryPoolTraceManager::GetInstance().CreateDomain(domainName);
    }
    return nullptr;
}
/**
 * @brief Forward a heap registration to the memory-pool tracer bound to @p domain.
 * @param domain Domain handle used to look up the tracer.
 * @param desc   Heap description, passed unchanged to the tracer's Allocate().
 * @return The handle produced by the tracer, or nullptr when no tracer is found for @p domain.
 */
mstxMemHeapHandle_t MstxManager::ReportHeapRegister(mstxDomainHandle_t domain, mstxMemHeapDesc_t const *desc)
{
    auto poolTracer = MemoryPoolTraceManager::GetInstance().GetMemoryPoolTracer(domain);
    if (!poolTracer) {
        return nullptr;
    }
    return poolTracer->Allocate(domain, desc);
}
/**
 * @brief Forward a heap unregistration to the memory-pool tracer bound to @p domain.
 *
 * This is a no-op when no tracer is found for @p domain.
 *
 * @param domain Domain handle used to look up the tracer.
 * @param heap   Heap handle, passed unchanged to the tracer's Deallocate().
 */
void MstxManager::ReportHeapUnregister(mstxDomainHandle_t domain, mstxMemHeapHandle_t heap)
{
    auto poolTracer = MemoryPoolTraceManager::GetInstance().GetMemoryPoolTracer(domain);
    if (!poolTracer) {
        return;
    }
    poolTracer->Deallocate(domain, heap);
}
/**
 * @brief Forward a batch region registration to the memory-pool tracer bound to @p domain.
 *
 * This is a no-op when no tracer is found for @p domain.
 *
 * @param domain Domain handle used to look up the tracer.
 * @param desc   Region batch, passed unchanged to the tracer's Reallocate().
 */
void MstxManager::ReportRegionsRegister(mstxDomainHandle_t domain, mstxMemRegionsRegisterBatch_t const *desc)
{
    auto poolTracer = MemoryPoolTraceManager::GetInstance().GetMemoryPoolTracer(domain);
    if (!poolTracer) {
        return;
    }
    poolTracer->Reallocate(domain, desc);
}
/**
 * @brief Forward a batch region unregistration to the memory-pool tracer bound to @p domain.
 *
 * This is a no-op when no tracer is found for @p domain.
 *
 * @param domain Domain handle used to look up the tracer.
 * @param desc   Region batch, passed unchanged to the tracer's Release().
 */
void MstxManager::ReportRegionsUnregister(mstxDomainHandle_t domain, mstxMemRegionsUnregisterBatch_t const *desc)
{
    auto poolTracer = MemoryPoolTraceManager::GetInstance().GetMemoryPoolTracer(domain);
    if (!poolTracer) {
        return;
    }
    poolTracer->Release(domain, desc);
}
}