* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include "mockcpp/mockcpp.hpp"
#include "prof_mstx_plugin.h"
#include "mstx_def.h"
using namespace ProfApi::MstxPlugin;
namespace {
// Call counters bumped by the core-module stubs below (mark / range start / range end).
int g_markACalled{0};
int g_rangeStartACalled{0};
int g_rangeEndCalled{0};

// Call counters bumped by the domain-module stubs below.
int g_domainCreateCalled{0};
int g_domainDestroyCalled{0};
int g_domainMarkACalled{0};
int g_domainRangeStartCalled{0};
int g_domainRangeEndCalled{0};

// Placeholder domain object; StubMstxDomainCreate returns its address so tests can
// recognise the handle that came from the stub.
MstxDomainHandle g_fakeDomain;
// Sets every stub call counter back to zero. The fixture calls this from SetUp().
void ResetCounters()
{
    // Core-module counters.
    g_markACalled = g_rangeStartACalled = g_rangeEndCalled = 0;
    // Domain-module counters.
    g_domainCreateCalled = g_domainDestroyCalled = g_domainMarkACalled = 0;
    g_domainRangeStartCalled = g_domainRangeEndCalled = 0;
}
// Stub written into the MSTX_FUNC_MARKA slot by FakeMstxInit.
// Ignores both arguments and only records that it was called.
void StubMstxMarkA(const char * /* message */, aclrtStream /* stream */)
{
    ++g_markACalled;
}
// Stub written into the MSTX_FUNC_RANGE_STARTA slot by FakeMstxInit.
// Records the call and returns the fixed id 42, which the dispatch test compares against.
mstxRangeId StubMstxRangeStartA(const char * /* message */, aclrtStream /* stream */)
{
    g_rangeStartACalled += 1;
    return 42;
}
// Stub written into the MSTX_FUNC_RANGE_END slot by FakeMstxInit.
// Ignores the range id and only records that it was called.
void StubMstxRangeEnd(mstxRangeId /* id */)
{
    ++g_rangeEndCalled;
}
// Stub written into the MSTX_FUNC_DOMAIN_CREATEA slot by FakeMstxInit.
// Records the call and always hands back the address of g_fakeDomain, whatever name is passed.
mstxDomainHandle_t StubMstxDomainCreate(const char * /* name */)
{
    g_domainCreateCalled += 1;
    mstxDomainHandle_t fakeHandle = &g_fakeDomain;
    return fakeHandle;
}
// Stub written into the MSTX_FUNC_DOMAIN_DESTROY slot by FakeMstxInit.
// Releases nothing (the fake domain is a global); only records that it was called.
void StubMstxDomainDestroy(mstxDomainHandle_t /* domain */)
{
    ++g_domainDestroyCalled;
}
// Stub written into the MSTX_FUNC_DOMAIN_MARKA slot by FakeMstxInit.
// Ignores all arguments and only records that it was called.
void StubMstxDomainMarkA(mstxDomainHandle_t /* domain */, const char * /* message */, aclrtStream /* stream */)
{
    ++g_domainMarkACalled;
}
// Stub written into the MSTX_FUNC_DOMAIN_RANGE_STARTA slot by FakeMstxInit.
// Records the call and returns the fixed id 99, which the dispatch test compares against.
mstxRangeId StubMstxDomainRangeStartA(mstxDomainHandle_t /* domain */, const char * /* message */,
                                      aclrtStream /* stream */)
{
    g_domainRangeStartCalled += 1;
    return 99;
}
// Stub written into the MSTX_FUNC_DOMAIN_RANGE_END slot by FakeMstxInit.
// Ignores both arguments and only records that it was called.
void StubMstxDomainRangeEnd(mstxDomainHandle_t /* domain */, mstxRangeId /* id */)
{
    ++g_domainRangeEndCalled;
}
// Fake injection-init callback that the tests pass to ProfRegisterMstxFunc.
//
// Asks getFuncTable for the core and the core-domain function tables. For each
// table that comes back with MSTX_SUCCESS, is non-null and is large enough, the
// matching Stub* function is stored into its slots.
//
// Returns MSTX_FAIL when getFuncTable is null. Otherwise returns MSTX_SUCCESS,
// even if one or both tables were unavailable and therefore skipped.
int FakeMstxInit(MstxGetModuleFuncTableFunc getFuncTable)
{
    // Fail cleanly instead of calling through a null pointer; this matches what the
    // tests in this file expect from InitInjectionMstx(nullptr).
    if (getFuncTable == nullptr) {
        return MSTX_FAIL;
    }

    // Each query uses its own out-params, so a getter that reports success without
    // writing them can never cause one module's stubs to land in the other's table.
    MstxFuncTable coreTable = nullptr;
    unsigned int coreSize = 0;
    // NOTE(review): MSTX_FUNC_RANGE_END serves as the bound for all three core slots;
    // presumably it is the largest of the three indices - confirm against mstx_def.h.
    if (getFuncTable(MSTX_API_MODULE_CORE, &coreTable, &coreSize) == MSTX_SUCCESS &&
        coreTable != nullptr && coreSize > MSTX_FUNC_RANGE_END) {
        *(coreTable[MSTX_FUNC_MARKA]) = reinterpret_cast<MstxFuncPointer>(StubMstxMarkA);
        *(coreTable[MSTX_FUNC_RANGE_STARTA]) = reinterpret_cast<MstxFuncPointer>(StubMstxRangeStartA);
        *(coreTable[MSTX_FUNC_RANGE_END]) = reinterpret_cast<MstxFuncPointer>(StubMstxRangeEnd);
    }

    MstxFuncTable domainTable = nullptr;
    unsigned int domainSize = 0;
    // NOTE(review): same assumption as above for MSTX_FUNC_DOMAIN_RANGE_END.
    if (getFuncTable(MSTX_API_MODULE_CORE_DOMAIN, &domainTable, &domainSize) == MSTX_SUCCESS &&
        domainTable != nullptr && domainSize > MSTX_FUNC_DOMAIN_RANGE_END) {
        *(domainTable[MSTX_FUNC_DOMAIN_CREATEA]) = reinterpret_cast<MstxFuncPointer>(StubMstxDomainCreate);
        *(domainTable[MSTX_FUNC_DOMAIN_DESTROY]) = reinterpret_cast<MstxFuncPointer>(StubMstxDomainDestroy);
        *(domainTable[MSTX_FUNC_DOMAIN_MARKA]) = reinterpret_cast<MstxFuncPointer>(StubMstxDomainMarkA);
        *(domainTable[MSTX_FUNC_DOMAIN_RANGE_STARTA]) =
            reinterpret_cast<MstxFuncPointer>(StubMstxDomainRangeStartA);
        *(domainTable[MSTX_FUNC_DOMAIN_RANGE_END]) = reinterpret_cast<MstxFuncPointer>(StubMstxDomainRangeEnd);
    }
    return MSTX_SUCCESS;
}
// Fake injection-init callback that never touches the getter and always reports MSTX_FAIL.
int FakeMstxInitFail(MstxGetModuleFuncTableFunc /* getFuncTable */)
{
    return MSTX_FAIL;
}
// Fake table getter that always reports MSTX_FAIL and leaves both out-params untouched.
int FakeGetTableFail(MstxFuncModule /* module */, MstxFuncTable * /* outTable */, unsigned int * /* outSize */)
{
    return MSTX_FAIL;
}
// Fake table getter that reports MSTX_SUCCESS but hands back an empty result:
// a null table and a size of zero, for any requested module.
// Both out-params are written unconditionally, so callers must pass valid pointers.
int FakeGetTableEmpty(MstxFuncModule /* module */, MstxFuncTable *outTable, unsigned int *outSize)
{
    *outSize = 0;
    *outTable = nullptr;
    return MSTX_SUCCESS;
}
}
class PROF_MSTX_PLUGIN_UTEST : public testing::Test {
protected:
virtual void SetUp()
{
ResetCounters();
}
virtual void TearDown()
{
GlobalMockObject::verify();
}
};
// Smoke test: registers a null init function for each profiler module.
// There are no assertions; the test passes as long as both calls return.
TEST_F(PROF_MSTX_PLUGIN_UTEST, ProfRegisterMstxFunc_NullInitFunc)
{
    ::ProfRegisterMstxFunc(nullptr, PROF_MODULE_MSPROF);
    ::ProfRegisterMstxFunc(nullptr, PROF_MODULE_MSPTI);
}
// Smoke test: registers an init function under MODULE_INVALID.
// There are no assertions; the test passes as long as the call returns.
TEST_F(PROF_MSTX_PLUGIN_UTEST, ProfRegisterMstxFunc_InvalidModule)
{
    ::ProfRegisterMstxFunc(FakeMstxInitFail, MODULE_INVALID);
}
// Registers the stubs for the MSPROF module, enables that module, and checks that each
// Mstx*Impl entry point reaches the corresponding stub (via the call counters and the
// stubs' fixed return values).
TEST_F(PROF_MSTX_PLUGIN_UTEST, ProfRegisterMstxFunc_Msprof_Then_ImplDispatch)
{
    ::ProfRegisterMstxFunc(FakeMstxInit, PROF_MODULE_MSPROF);
    ::ProfEnableMstxFunc(PROF_MODULE_MSPROF);

    // --- Core module: mark, range start, range end ---
    MstxMarkAImpl("hello", nullptr);
    EXPECT_EQ(1, g_markACalled);

    const auto coreRangeId = MstxRangeStartAImpl("range", nullptr);
    EXPECT_EQ(42u, coreRangeId);
    EXPECT_EQ(1, g_rangeStartACalled);

    MstxRangeEndImpl(42);
    EXPECT_EQ(1, g_rangeEndCalled);

    // --- Tool id reported while MSPROF is the enabled module ---
    uint64_t reportedToolId = 0;
    MstxGetToolIdImpl(&reportedToolId);
    EXPECT_EQ(static_cast<uint64_t>(MSTX_TOOLS_MSPROF_ID), reportedToolId);

    // --- Domain module: create, mark, range start/end, destroy ---
    auto handle = MstxDomainCreateAImpl("dom");
    EXPECT_EQ(&g_fakeDomain, handle);
    // Create is checked with >= rather than ==, so additional create calls are tolerated.
    EXPECT_GE(g_domainCreateCalled, 1);

    MstxDomainMarkAImpl(handle, "msg", nullptr);
    EXPECT_EQ(1, g_domainMarkACalled);

    const auto domainRangeId = MstxDomainRangeStartAImpl(handle, "msg", nullptr);
    EXPECT_EQ(99u, domainRangeId);
    EXPECT_EQ(1, g_domainRangeStartCalled);

    MstxDomainRangeEndImpl(handle, 99);
    EXPECT_EQ(1, g_domainRangeEndCalled);

    MstxDomainDestroyImpl(handle);
    // Destroy is likewise checked with >=.
    EXPECT_GE(g_domainDestroyCalled, 1);
}
// Registers the stubs for both modules, enables MSPTI and checks that mark and domain
// creation reach the stubs, then switches back to MSPROF and issues one more call.
TEST_F(PROF_MSTX_PLUGIN_UTEST, ProfRegisterMstxFunc_Mspti_Then_EnableSwitch)
{
    ::ProfRegisterMstxFunc(FakeMstxInit, PROF_MODULE_MSPROF);
    ::ProfRegisterMstxFunc(FakeMstxInit, PROF_MODULE_MSPTI);

    // With MSPTI enabled the stubs must still be reachable.
    ::ProfEnableMstxFunc(PROF_MODULE_MSPTI);
    MstxMarkAImpl("hi", nullptr);
    EXPECT_GE(g_markACalled, 1);

    auto handle = MstxDomainCreateAImpl("dom");
    EXPECT_EQ(&g_fakeDomain, handle);

    // Switch to MSPROF; the range-end call has no assertion and only needs to return.
    ::ProfEnableMstxFunc(PROF_MODULE_MSPROF);
    MstxRangeEndImpl(0);
}
// Calls the domain entry points with a null domain handle; the test mainly requires
// that these calls return.
// NOTE(review): nothing is registered or enabled here, so the outcome presumably depends
// on plugin state left behind by earlier tests - confirm whether that is intended.
TEST_F(PROF_MSTX_PLUGIN_UTEST, ImplDispatch_NullDomainPassthrough)
{
    MstxDomainDestroyImpl(nullptr);

    MstxDomainMarkAImpl(nullptr, "x", nullptr);
    // NOTE(review): the counter is only ever reset to 0 or incremented in this file, so
    // this check cannot fail. Confirm the intended expectation for a null domain.
    EXPECT_GE(g_domainMarkACalled, 0);

    MstxDomainRangeEndImpl(nullptr, 0);
}
// Exercises MsprofMstxGetModuleFuncTable: both valid modules must succeed and fill the
// out-params; an invalid module or a null out-param must yield MSTX_FAIL.
TEST_F(PROF_MSTX_PLUGIN_UTEST, MsprofMstxGetModuleFuncTable_Variants)
{
    MstxFuncTable outTable = nullptr;
    unsigned int outSize = 0;

    // Core module.
    EXPECT_EQ(MSTX_SUCCESS, MsprofMstxGetModuleFuncTable(MSTX_API_MODULE_CORE, &outTable, &outSize));
    EXPECT_NE(outTable, nullptr);
    EXPECT_GT(outSize, 0u);

    // Clear the out-params first, so the checks below can only pass if the core-domain
    // query itself wrote them and not because of values left over from the core query.
    outTable = nullptr;
    outSize = 0;
    EXPECT_EQ(MSTX_SUCCESS, MsprofMstxGetModuleFuncTable(MSTX_API_MODULE_CORE_DOMAIN, &outTable, &outSize));
    EXPECT_NE(outTable, nullptr);
    EXPECT_GT(outSize, 0u);

    // Rejected inputs: unknown module, null table pointer, null size pointer.
    EXPECT_EQ(MSTX_FAIL, MsprofMstxGetModuleFuncTable(MSTX_API_MODULE_INVALID, &outTable, &outSize));
    EXPECT_EQ(MSTX_FAIL, MsprofMstxGetModuleFuncTable(MSTX_API_MODULE_CORE, nullptr, &outSize));
    EXPECT_EQ(MSTX_FAIL, MsprofMstxGetModuleFuncTable(MSTX_API_MODULE_CORE, &outTable, nullptr));
}
// Exercises MsptiMstxGetModuleFuncTable: both valid modules must report MSTX_SUCCESS and
// an invalid module must report MSTX_FAIL. Only the status codes are checked here.
TEST_F(PROF_MSTX_PLUGIN_UTEST, MsptiMstxGetModuleFuncTable_Variants)
{
    MstxFuncTable table = nullptr;
    unsigned int tableSize = 0;

    EXPECT_EQ(MSTX_SUCCESS, MsptiMstxGetModuleFuncTable(MSTX_API_MODULE_CORE, &table, &tableSize));
    EXPECT_EQ(MSTX_SUCCESS, MsptiMstxGetModuleFuncTable(MSTX_API_MODULE_CORE_DOMAIN, &table, &tableSize));

    EXPECT_EQ(MSTX_FAIL, MsptiMstxGetModuleFuncTable(MSTX_API_MODULE_INVALID, &table, &tableSize));
}
// A null table getter must make InitInjectionMstx report MSTX_FAIL.
TEST_F(PROF_MSTX_PLUGIN_UTEST, InitInjectionMstx_Null)
{
    const auto status = ::InitInjectionMstx(nullptr);
    EXPECT_EQ(MSTX_FAIL, status);
}
// A getter that always fails is tolerated: InitInjectionMstx is still expected to
// report MSTX_SUCCESS.
TEST_F(PROF_MSTX_PLUGIN_UTEST, InitInjectionMstx_GetTableFail)
{
    const auto status = ::InitInjectionMstx(FakeGetTableFail);
    EXPECT_EQ(MSTX_SUCCESS, status);
}
// A getter that succeeds with a null, zero-sized table is tolerated: InitInjectionMstx
// is still expected to report MSTX_SUCCESS.
TEST_F(PROF_MSTX_PLUGIN_UTEST, InitInjectionMstx_EmptyTable)
{
    const auto status = ::InitInjectionMstx(FakeGetTableEmpty);
    EXPECT_EQ(MSTX_SUCCESS, status);
}
// Feeds the plugin's own table getters back into InitInjectionMstx; both the MSPROF and
// the MSPTI getter are expected to yield MSTX_SUCCESS.
TEST_F(PROF_MSTX_PLUGIN_UTEST, InitInjectionMstx_RealRoundTrip)
{
    const auto msprofStatus = ::InitInjectionMstx(MsprofMstxGetModuleFuncTable);
    EXPECT_EQ(MSTX_SUCCESS, msprofStatus);

    const auto msptiStatus = ::InitInjectionMstx(MsptiMstxGetModuleFuncTable);
    EXPECT_EQ(MSTX_SUCCESS, msptiStatus);
}
// Smoke test: enables MSPTI, then MSPROF.
// There are no assertions; the test passes as long as both calls return.
TEST_F(PROF_MSTX_PLUGIN_UTEST, EnableMstxFunc)
{
    ::ProfEnableMstxFunc(PROF_MODULE_MSPTI);
    ::ProfEnableMstxFunc(PROF_MODULE_MSPROF);
}