* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <sstream>
#define private public
#include "event_trace/memory_pool_trace/memory_pool_trace_manager.h"
#include "event_trace/mstx_hooks/mstx_inject.h"
#include "event_trace/memory_pool_trace/atb_memory_pool_trace.h"
#include "event_trace/memory_pool_trace/mindspore_memory_pool_trace.h"
#include "event_trace/memory_pool_trace/pta_caching_pool_trace.h"
#include "event_trace/memory_pool_trace/pta_workspace_pool_trace.h"
#undef private
#include <gtest/gtest.h>
using namespace MemScope;
// Drives the ATB tracer through the public mstx hook entry points:
// - creating the "atb" domain must yield ATBMemoryPoolTrace's domain;
// - registering a heap of size 500 on device 1 must set totalReserved to 500;
// - registering a region of size 50 must set totalAllocated to 50;
// - unregistering both must bring the counters back to 0.
TEST(MemPoolTraceTest, MemPoolTraceTestATBHeapRegisterAndRegionRegister)
{
    auto domainHandle = MstxDomainCreateAFunc("atb");
    EXPECT_EQ(domainHandle, ATBMemoryPoolTrace::GetInstance().atbDomain_);

    const uint32_t deviceId = 1;

    // Descriptors are value-initialised ({}) so fields this test does not set are zero
    // rather than indeterminate.
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;
    auto heapHandle = MstxMemHeapRegisterFunc(domainHandle, &heapDesc);
    EXPECT_EQ(ATBMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 500);

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange{deviceId, ptr, 50};
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t registerBatch{};
    registerBatch.regionCount = 1;
    registerBatch.regionDescArray = &regionRange;
    // Out-array for the region handle; its value is passed back below to unregister it.
    registerBatch.regionHandleArrayOut = &regionHandle;
    MstxMemRegionsRegisterFunc(domainHandle, &registerBatch);
    EXPECT_EQ(ATBMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 50);

    // The region is referenced by the handle obtained at registration. The reference is a
    // stack object, so nothing is leaked.
    mstxMemRegionRef_t regionRef{};
    regionRef.handle = regionHandle;
    mstxMemRegionsUnregisterBatch_t unregisterBatch{};
    unregisterBatch.refCount = 1;
    unregisterBatch.refArray = &regionRef;
    MstxMemRegionsUnregisterFunc(domainHandle, &unregisterBatch);
    EXPECT_EQ(ATBMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    MstxMemHeapUnregisterFunc(domainHandle, heapHandle);
    EXPECT_EQ(ATBMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);
}
// Calls Allocate on the ATB and MindSpore tracers with a null domain and a null heap
// descriptor; both are expected to return nullptr.
TEST(MemPoolTraceTest, MemPoolTraceTestAllocateGetNullPtr)
{
    mstxDomainHandle_t domain = nullptr;
    // Explicitly null: reading an uninitialised pointer would be undefined behaviour.
    mstxMemHeapDesc_t const *desc = nullptr;
    EXPECT_EQ(ATBMemoryPoolTrace::GetInstance().Allocate(domain, desc), nullptr);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().Allocate(domain, desc), nullptr);
}
// Calls Deallocate on the ATB and MindSpore tracers with a null domain and a null heap
// handle. There are no assertions; the test passes as long as neither call crashes.
TEST(MemPoolTraceTest, MemPoolTraceTestDeallocateGetNullPtr)
{
    mstxDomainHandle_t domain = nullptr;
    // Explicitly null: passing an uninitialised handle would be undefined behaviour.
    mstxMemHeapHandle_t desc = nullptr;
    ATBMemoryPoolTrace::GetInstance().Deallocate(domain, desc);
    MindsporeMemoryPoolTrace::GetInstance().Deallocate(domain, desc);
}
// Calls Reallocate on the ATB and MindSpore tracers with a null domain and a null
// register batch. There are no assertions; the test passes as long as neither call crashes.
TEST(MemPoolTraceTest, MemPoolTraceTestReallocateGetNullPtr)
{
    mstxDomainHandle_t domain = nullptr;
    // Explicitly null: reading an uninitialised pointer would be undefined behaviour.
    mstxMemRegionsRegisterBatch_t const *desc = nullptr;
    ATBMemoryPoolTrace::GetInstance().Reallocate(domain, desc);
    MindsporeMemoryPoolTrace::GetInstance().Reallocate(domain, desc);
}
// Calls Release on the ATB and MindSpore tracers with a null domain and a null
// unregister batch. There are no assertions; the test passes as long as neither call crashes.
TEST(MemPoolTraceTest, MemPoolTraceTestReleaseGetNullPtr)
{
    mstxDomainHandle_t domain = nullptr;
    // Explicitly null: reading an uninitialised pointer would be undefined behaviour.
    mstxMemRegionsUnregisterBatch_t const *desc = nullptr;
    ATBMemoryPoolTrace::GetInstance().Release(domain, desc);
    MindsporeMemoryPoolTrace::GetInstance().Release(domain, desc);
}
// Asks the ATB tracer for a domain named "test" and expects it to decline (nullptr).
TEST(MemPoolTraceTest, ATBMemoryPoolTraceCreateDomainReturnNull)
{
    const auto createdDomain = ATBMemoryPoolTrace::GetInstance().CreateDomain("test");
    EXPECT_EQ(createdDomain, nullptr);
}
// Drives the MindSpore tracer through the public mstx hook entry points:
// - the "mindsporeMemPool" domain must map to MindsporeMemoryPoolTrace's domain;
// - device 1's totalReserved must go to 500 after the heap is registered;
// - totalAllocated must go to 50 after the region is registered;
// - both counters must return to 0 after unregistering.
TEST(MemPoolTraceTest, MemPoolTraceTestMindsporeHeapRegisterAndRegionRegister)
{
    auto domainHandle = MstxDomainCreateAFunc("mindsporeMemPool");
    EXPECT_EQ(domainHandle, MindsporeMemoryPoolTrace::GetInstance().mindsporeDomain_);

    const uint32_t deviceId = 1;

    // Descriptors are value-initialised ({}) so unset fields are zero, not indeterminate.
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;
    // The returned handle is not kept: the heap is unregistered by its address below.
    MstxMemHeapRegisterFunc(domainHandle, &heapDesc);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 500);

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange{deviceId, ptr, 50};
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t registerBatch{};
    registerBatch.regionCount = 1;
    registerBatch.regionDescArray = &regionRange;
    registerBatch.regionHandleArrayOut = &regionHandle;
    MstxMemRegionsRegisterFunc(domainHandle, &registerBatch);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 50);

    // The region is referenced by its base pointer. The reference is a stack object, so
    // nothing is leaked.
    mstxMemRegionRef_t regionRef{};
    regionRef.pointer = ptr;
    mstxMemRegionsUnregisterBatch_t unregisterBatch{};
    unregisterBatch.refCount = 1;
    unregisterBatch.refArray = &regionRef;
    MstxMemRegionsUnregisterFunc(domainHandle, &unregisterBatch);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    // The heap handle passed here is the heap's base address reinterpreted as a handle.
    MstxMemHeapUnregisterFunc(domainHandle, static_cast<mstxMemHeapHandle_t>(const_cast<void *>(ptrHeap)));
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);
}
// Uses the domain name "mindspore", which the test expects NOT to map to the MindSpore
// tracer's domain. Registering and unregistering a heap and a region through that domain
// must therefore leave the tracer's device-1 counters at 0.
// NOTE(review): GetInstance() appears to return a shared instance, so the zero checks
// assume earlier tests in this binary left device 1's counters at 0.
TEST(MemPoolTraceTest, MemPoolTraceTestMindsporeHeapRegisterAndRegionRegisterReturnNull)
{
    auto domainHandle = MstxDomainCreateAFunc("mindspore");
    EXPECT_NE(domainHandle, MindsporeMemoryPoolTrace::GetInstance().mindsporeDomain_);

    const uint32_t deviceId = 1;

    // Descriptors are value-initialised ({}) so unset fields are zero, not indeterminate.
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;
    auto heapHandle = MstxMemHeapRegisterFunc(domainHandle, &heapDesc);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange{deviceId, ptr, 50};
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t registerBatch{};
    registerBatch.regionCount = 1;
    registerBatch.regionDescArray = &regionRange;
    registerBatch.regionHandleArrayOut = &regionHandle;
    MstxMemRegionsRegisterFunc(domainHandle, &registerBatch);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    // Stack-allocated reference, so nothing is leaked.
    mstxMemRegionRef_t regionRef{};
    regionRef.handle = regionHandle;
    mstxMemRegionsUnregisterBatch_t unregisterBatch{};
    unregisterBatch.refCount = 1;
    unregisterBatch.refArray = &regionRef;
    MstxMemRegionsUnregisterFunc(domainHandle, &unregisterBatch);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    MstxMemHeapUnregisterFunc(domainHandle, heapHandle);
    EXPECT_EQ(MindsporeMemoryPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);
}
// Calls MindsporeMemoryPoolTrace::Release with one handle-type region reference after
// emptying regionHandleMp_. There are no assertions; the test passes if the call does
// not crash.
TEST(MemPoolTraceTest, MindsporeMemoryPoolTraceReleaseRegionHandleMpNull)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemRegionRef_t regionRef = {mstxMemRegionRefType::MSTX_MEM_REGION_REF_TYPE_HANDLE, ptr};
    // Value-initialised so any field not set below is zero rather than indeterminate.
    mstxMemRegionsUnregisterBatch_t desc{};
    desc.refCount = 1;
    desc.refArray = &regionRef;

    MindsporeMemoryPoolTrace::GetInstance().regionHandleMp_ = {};
    MindsporeMemoryPoolTrace::GetInstance().Release(domainHandle, &desc);
}
// Calls MemoryPoolTraceManager::Release with a freshly value-initialised domain
// registration and one handle-type region reference. There are no assertions; the test
// passes if the call does not crash.
TEST(MemPoolTraceTest, MemoryPoolTraceManagerRelease)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemRegionRef_t regionRef = {mstxMemRegionRefType::MSTX_MEM_REGION_REF_TYPE_HANDLE, ptr};
    // Value-initialised so any field not set below is zero rather than indeterminate.
    mstxMemRegionsUnregisterBatch_t desc{};
    desc.refCount = 1;
    desc.refArray = &regionRef;

    MemoryPoolTraceManager::GetInstance().Release(domainHandle, &desc);
}
// Calls MemoryPoolTraceManager::Allocate with a freshly value-initialised domain
// registration and a heap descriptor (device 1, size 500). There are no assertions; the
// test passes if the call does not crash.
TEST(MemPoolTraceTest, MemoryPoolTraceManagerAllocate)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    const uint32_t deviceId = 1;
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    // Value-initialised so fields other than typeSpecificDesc are zero, not indeterminate.
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;

    MemoryPoolTraceManager::GetInstance().Allocate(domainHandle, &heapDesc);
}
// Calls MemoryPoolTraceManager::Deallocate with a freshly value-initialised domain
// registration and a null heap handle. There are no assertions; the test passes if the
// call does not crash.
TEST(MemPoolTraceTest, MemoryPoolTraceManagerDeallocate)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;
    // Explicitly null: passing an uninitialised handle would be undefined behaviour.
    mstxMemHeapHandle_t heap = nullptr;
    MemoryPoolTraceManager::GetInstance().Deallocate(domainHandle, heap);
}
// Calls MemoryPoolTraceManager::Reallocate with a freshly value-initialised domain
// registration and one region (device 1, size 50). There are no assertions; the test
// passes if the call does not crash.
TEST(MemPoolTraceTest, MemoryPoolTraceManagerReallocate)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange = {1, ptr, 50};
    // Valid storage for the handle out-array, so the call never writes through an
    // indeterminate pointer if it fills regionHandleArrayOut.
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t desc{};
    desc.regionCount = 1;
    desc.regionDescArray = &regionRange;
    desc.regionHandleArrayOut = &regionHandle;

    MemoryPoolTraceManager::GetInstance().Reallocate(domainHandle, &desc);
}
// Asks the MindSpore tracer for a domain named "mdspe" and expects it to decline (nullptr).
TEST(MemPoolTraceTest, MindsporeMemoryPoolTraceCreateDomainReturnNull)
{
    const auto createdDomain = MindsporeMemoryPoolTrace::GetInstance().CreateDomain("mdspe");
    EXPECT_EQ(createdDomain, nullptr);
}
// Drives the PTA caching-pool tracer through the public mstx hook entry points:
// - the "ptaCaching" domain must map to PTACachingPoolTrace's domain;
// - device 1's totalReserved must go to 500 after the heap is registered;
// - totalAllocated must go to 50 after the region is registered;
// - both counters must return to 0 after unregistering.
TEST(MemPoolTraceTest, PTACachingPoolTraceTestPtaHeapRegisterAndRegionRegister)
{
    auto domainHandle = MstxDomainCreateAFunc("ptaCaching");
    EXPECT_EQ(domainHandle, PTACachingPoolTrace::GetInstance().ptaCachingDomain_);

    const uint32_t deviceId = 1;

    // Descriptors are value-initialised ({}) so unset fields are zero, not indeterminate.
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;
    // The returned handle is not kept: the heap is unregistered by its address below.
    MstxMemHeapRegisterFunc(domainHandle, &heapDesc);
    EXPECT_EQ(PTACachingPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 500);

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange{deviceId, ptr, 50};
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t registerBatch{};
    registerBatch.regionCount = 1;
    registerBatch.regionDescArray = &regionRange;
    registerBatch.regionHandleArrayOut = &regionHandle;
    MstxMemRegionsRegisterFunc(domainHandle, &registerBatch);
    EXPECT_EQ(PTACachingPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 50);

    // The region is referenced by its base pointer. The reference is a stack object, so
    // nothing is leaked.
    mstxMemRegionRef_t regionRef{};
    regionRef.pointer = ptr;
    mstxMemRegionsUnregisterBatch_t unregisterBatch{};
    unregisterBatch.refCount = 1;
    unregisterBatch.refArray = &regionRef;
    MstxMemRegionsUnregisterFunc(domainHandle, &unregisterBatch);
    EXPECT_EQ(PTACachingPoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    // The heap handle passed here is the heap's base address reinterpreted as a handle.
    MstxMemHeapUnregisterFunc(domainHandle, static_cast<mstxMemHeapHandle_t>(const_cast<void *>(ptrHeap)));
    EXPECT_EQ(PTACachingPoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);
}
// Calls PTACachingPoolTrace::Release with one handle-type region reference after
// emptying regionHandleMp_. There are no assertions; the test passes if the call does
// not crash.
TEST(MemPoolTraceTest, PTACachingPoolTraceReleaseRegionHandleMpNull)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemRegionRef_t regionRef = {mstxMemRegionRefType::MSTX_MEM_REGION_REF_TYPE_HANDLE, ptr};
    // Value-initialised so any field not set below is zero rather than indeterminate.
    mstxMemRegionsUnregisterBatch_t desc{};
    desc.refCount = 1;
    desc.refArray = &regionRef;

    PTACachingPoolTrace::GetInstance().regionHandleMp_ = {};
    PTACachingPoolTrace::GetInstance().Release(domainHandle, &desc);
}
// Asks the PTA caching-pool tracer for a domain named "msmemscope" and expects it to
// decline (nullptr).
TEST(MemPoolTraceTest, PTACachingPoolTraceCreateDomainReturnNull)
{
    const auto createdDomain = PTACachingPoolTrace::GetInstance().CreateDomain("msmemscope");
    EXPECT_EQ(createdDomain, nullptr);
}
// Drives the PTA workspace-pool tracer through the public mstx hook entry points:
// - the "ptaWorkspace" domain must map to PTAWorkspacePoolTrace's domain;
// - device 1's totalReserved must go to 500 after the heap is registered;
// - totalAllocated must go to 50 after the region is registered;
// - both counters must return to 0 after unregistering.
TEST(MemPoolTraceTest, PTAWorkspacePoolTraceTestPtaHeapRegisterAndRegionRegister)
{
    auto domainHandle = MstxDomainCreateAFunc("ptaWorkspace");
    EXPECT_EQ(domainHandle, PTAWorkspacePoolTrace::GetInstance().ptaWorkspaceDomain_);

    const uint32_t deviceId = 1;

    // Descriptors are value-initialised ({}) so unset fields are zero, not indeterminate.
    void const *ptrHeap = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t heapRange{deviceId, ptrHeap, 500};
    mstxMemHeapDesc_t heapDesc{};
    heapDesc.typeSpecificDesc = &heapRange;
    // The returned handle is not kept: the heap is unregistered by its address below.
    MstxMemHeapRegisterFunc(domainHandle, &heapDesc);
    EXPECT_EQ(PTAWorkspacePoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 500);

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemVirtualRangeDesc_t regionRange{deviceId, ptr, 50};
    mstxMemRegionHandle_t regionHandle{};
    mstxMemRegionsRegisterBatch_t registerBatch{};
    registerBatch.regionCount = 1;
    registerBatch.regionDescArray = &regionRange;
    registerBatch.regionHandleArrayOut = &regionHandle;
    MstxMemRegionsRegisterFunc(domainHandle, &registerBatch);
    EXPECT_EQ(PTAWorkspacePoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 50);

    // The region is referenced by its base pointer. The reference is a stack object, so
    // nothing is leaked.
    mstxMemRegionRef_t regionRef{};
    regionRef.pointer = ptr;
    mstxMemRegionsUnregisterBatch_t unregisterBatch{};
    unregisterBatch.refCount = 1;
    unregisterBatch.refArray = &regionRef;
    MstxMemRegionsUnregisterFunc(domainHandle, &unregisterBatch);
    EXPECT_EQ(PTAWorkspacePoolTrace::GetInstance().memUsageMp_[deviceId].totalAllocated, 0);

    // The heap handle passed here is the heap's base address reinterpreted as a handle.
    MstxMemHeapUnregisterFunc(domainHandle, static_cast<mstxMemHeapHandle_t>(const_cast<void *>(ptrHeap)));
    EXPECT_EQ(PTAWorkspacePoolTrace::GetInstance().memUsageMp_[deviceId].totalReserved, 0);
}
// Calls PTAWorkspacePoolTrace::Release with one handle-type region reference after
// emptying regionHandleMp_. There are no assertions; the test passes if the call does
// not crash.
TEST(MemPoolTraceTest, PTAWorkspacePoolTraceReleaseRegionHandleMpNull)
{
    mstxDomainRegistration_t domainRegistration{};
    mstxDomainHandle_t domainHandle = &domainRegistration;

    void const *ptr = reinterpret_cast<void const *>(123);
    mstxMemRegionRef_t regionRef = {mstxMemRegionRefType::MSTX_MEM_REGION_REF_TYPE_HANDLE, ptr};
    // Value-initialised so any field not set below is zero rather than indeterminate.
    mstxMemRegionsUnregisterBatch_t desc{};
    desc.refCount = 1;
    desc.refArray = &regionRef;

    PTAWorkspacePoolTrace::GetInstance().regionHandleMp_ = {};
    PTAWorkspacePoolTrace::GetInstance().Release(domainHandle, &desc);
}
// Asks the PTA workspace-pool tracer for a domain named "msmemscope" and expects it to
// decline (nullptr).
TEST(MemPoolTraceTest, PTAWorkspacePoolTraceCreateDomainReturnNull)
{
    const auto createdDomain = PTAWorkspacePoolTrace::GetInstance().CreateDomain("msmemscope");
    EXPECT_EQ(createdDomain, nullptr);
}