* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include <gtest/gtest.h>
#include <thread>
#include <vector>
#include <dlfcn.h>
#define MSTX_NO_IMPL
#include "mstx/ms_tools_ext.h"
#undef MSTX_NO_IMPL
// EXTERN selects the linkage of the test wrappers declared below: C linkage when
// LANG_C_TEST is defined (presumably because the wrappers are then compiled as C —
// confirm against the build), plain `extern` otherwise.
#ifdef LANG_C_TEST
#define EXTERN extern "C"
#else
#define EXTERN extern
#endif
// Signatures of optional probe functions resolved at runtime via dlsym(nullptr, ...).
// NOTE(review): these look like they are exported by a stub/injected tool library;
// when a symbol is not found the tests below simply skip that check.
using GetRangeStateFunc = int (*)();
using GetDomainStateFunc = int (*)();
using GetDomainMarkMessageFunc = const char *(*)();
// Range id the tests expect a range start to return when initialization succeeded.
const int TEST_RANGE_ID = 123456;
// Tool id the tests expect MstxGetToolIdTest to report when initialization succeeded.
const uint64_t TOOL_ID = 1234;
// Test wrappers around the mstx API; their definitions are not in this file.
EXTERN mstxMemHeapHandle_t MstxMemHeapRegisterTest(mstxDomainHandle_t domain, mstxMemHeapDesc_t const *desc);
EXTERN void MstxMemHeapUnregisterTest(mstxDomainHandle_t domain, mstxMemHeapHandle_t heap);
EXTERN void MstxMemRegionsRegisterTest(mstxDomainHandle_t domain, mstxMemRegionsRegisterBatch_t const* desc);
EXTERN void MstxMemRegionsUnregisterTest(mstxDomainHandle_t domain, mstxMemRegionsUnregisterBatch_t const* desc);
EXTERN void MstxMemPermissionsAssignTest(mstxDomainHandle_t domain, const mstxMemPermissionsAssignBatch_t *desc);
EXTERN void MstxMarkATest(const char *message, aclrtStream stream);
EXTERN mstxRangeId MstxRangeStartATest(const char *message, aclrtStream stream);
EXTERN void MstxRangeEndTest(mstxRangeId id);
EXTERN void MstxGetToolIdTest(uint64_t *id);
EXTERN mstxDomainHandle_t MstxDomainCreateATest(const char *name);
EXTERN void MstxDomainDestroyTest(mstxDomainHandle_t handle);
EXTERN void MstxDomainMarkATest(mstxDomainHandle_t handle, const char *message, aclrtStream stream);
EXTERN mstxRangeId MstxDomainRangeStartATest(mstxDomainHandle_t handle, const char *message, aclrtStream stream);
EXTERN void MstxDomainRangeEndTest(mstxDomainHandle_t handle, mstxRangeId id);
// Initialization-state helpers. The tests treat GetInitResult() == 0 as "not
// initialized" (right after MstxDeInit) and == 2 as "initialization succeeded".
EXTERN int GetInitResult();
EXTERN void MstxDeInit();
EXTERN int MstxGetModuleFuncTableTest(mstxFuncModule module, mstxFuncTable* outTable, unsigned int* outSize);
EXTERN void MstxInit();
// The tests expect that after RefreshUninitMstxContextFuncPtrTest(...) the helper
// CheckFuncPointerAllNull() returns 1 (presumably "every context function pointer is null").
EXTERN void RefreshUninitMstxContextFuncPtrTest(int forceNull);
EXTERN int CheckFuncPointerAllNull();
void TestMultiThreadInit()
{
mstxRangeId id = MstxRangeStartATest("test", NULL);
MstxRangeEndTest(id);
}
// Hammers the API from many threads at once. Initialization must complete without
// crashing and, like every other API entry point, leave a non-zero init result.
TEST(CoreApi, test_init_with_multi_thread)
{
    constexpr int threadNum = 100;
    std::vector<std::thread> threads;
    threads.reserve(threadNum);
    for (int i = 0; i < threadNum; ++i) {
        threads.emplace_back(TestMultiThreadInit);
    }
    for (auto &worker : threads) {
        worker.join();
    }
    EXPECT_NE(GetInitResult(), 0);
}
// With successful initialization, RangeStartA returns TEST_RANGE_ID. If the optional
// GetRangeState probe is exported, it reports 1 while the range is open and 0 after RangeEnd.
TEST(CoreApi, test_range_start_a_with_range_end_if_init_success_expect_success)
{
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded; this case only applies then
        return;
    }
    const mstxRangeId id = MstxRangeStartATest("test", nullptr);
    ASSERT_EQ(id, static_cast<mstxRangeId>(TEST_RANGE_ID));
    auto stateFunc = reinterpret_cast<GetRangeStateFunc>(dlsym(nullptr, "GetRangeState"));
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 1);
    }
    MstxRangeEndTest(id);
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 0);
    }
}
// When initialization did not succeed, RangeStartA must return the invalid id 0, and
// ending that id must be harmless.
TEST(CoreApi, test_range_start_a_with_range_end_if_init_failed_expect_invalid_id)
{
    MstxInit();
    if (GetInitResult() == 2) {  // only meaningful when init did not succeed
        return;
    }
    const mstxRangeId id = MstxRangeStartATest("test", nullptr);
    ASSERT_EQ(id, static_cast<mstxRangeId>(0));
    MstxRangeEndTest(id);
}
// With successful initialization, MstxGetToolIdTest reports TOOL_ID.
TEST(CoreApi, test_get_tool_id_if_init_success_expect_success)
{
    uint64_t id = 0;
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded
        return;
    }
    MstxGetToolIdTest(&id);
    ASSERT_EQ(id, TOOL_ID);
}
// When initialization did not succeed, MstxGetToolIdTest must write 0. The id starts at
// a non-zero sentinel (7) so the assertion proves the call actually wrote the value.
TEST(CoreApi, test_get_tool_id_if_init_failed_expect_default_id)
{
    uint64_t id = 7;
    MstxInit();
    if (GetInitResult() == 2) {  // only meaningful when init did not succeed
        return;
    }
    MstxGetToolIdTest(&id);
    ASSERT_EQ(id, uint64_t{0});
}
// When initialization did not succeed, creating a domain yields a null handle.
TEST(CoreApi, test_domain_createA_if_init_fail_expect_nullptr)
{
    MstxInit();
    if (GetInitResult() == 2) {  // only meaningful when init did not succeed
        return;
    }
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_EQ(handle, nullptr);
}
// After MstxDeInit the init result is 0; the first call to RangeEnd triggers initialization.
TEST(CoreApi, test_init_with_range_end_func_expect_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxRangeEndTest(0);
    ASSERT_NE(GetInitResult(), 0);
}
// With successful initialization, creating a domain yields a non-null handle. The
// handle is destroyed afterwards so the test does not leak a domain.
TEST(CoreApi, test_domain_createA_if_init_succ_expect_valid_handle)
{
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded
        return;
    }
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_NE(handle, nullptr);
    MstxDomainDestroyTest(handle);
}
// With successful initialization, destroying a created domain must succeed. If the
// optional GetDomainState probe is exported, it reports 0 afterwards.
TEST(CoreApi, test_domain_destroy_if_init_success_expect_destroy_handle)
{
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded
        return;
    }
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_NE(handle, nullptr);
    MstxDomainDestroyTest(handle);
    auto stateFunc = reinterpret_cast<GetDomainStateFunc>(dlsym(nullptr, "GetDomainState"));
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 0);
    }
}
// Init-failure counterpart of the destroy test: destroying whatever handle the create
// call returns must not crash. There are no assertions; completing without a crash is the pass.
TEST(CoreApi, test_domain_destroy_if_init_fail_expect_destroy_handle)
{
    MstxInit();
    const bool initSucceeded = (GetInitResult() == 2);
    if (initSucceeded) {
        return;
    }
    MstxDomainDestroyTest(MstxDomainCreateATest("domain"));
}
// When initialization did not succeed, a domain mark must not be recorded. If the
// optional GetDomainMarkMessage probe is exported, it must report nullptr.
TEST(CoreApi, test_domain_markA_with_input_if_init_fail_expect_not_save_mark_msg)
{
    MstxInit();
    if (GetInitResult() == 2) {  // only meaningful when init did not succeed
        return;
    }
    const std::string msg = "test";
    const auto handle = MstxDomainCreateATest("domain");
    MstxDomainMarkATest(handle, msg.c_str(), nullptr);
    auto func = reinterpret_cast<GetDomainMarkMessageFunc>(dlsym(nullptr, "GetDomainMarkMessage"));
    if (func != nullptr) {
        const char *domainMarkMsg = func();
        ASSERT_EQ(domainMarkMsg, nullptr);
    }
    MstxDomainDestroyTest(handle);
}
// With successful initialization, a domain mark is forwarded. If the optional
// GetDomainMarkMessage probe is exported, it must report the same message text.
TEST(CoreApi, test_domain_markA_with_input_if_init_success_expect_save_mark_msg)
{
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded
        return;
    }
    const std::string msg = "test";
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_NE(handle, nullptr);
    MstxDomainMarkATest(handle, msg.c_str(), nullptr);
    auto func = reinterpret_cast<GetDomainMarkMessageFunc>(dlsym(nullptr, "GetDomainMarkMessage"));
    if (func != nullptr) {
        const char *domainMarkMsg = func();
        EXPECT_STREQ(domainMarkMsg, msg.c_str());
    }
    MstxDomainDestroyTest(handle);
}
// When initialization did not succeed, a domain range start returns the invalid id 0.
// If the optional GetRangeState probe is exported, it must not report an open range.
TEST(CoreApi, test_domain_range_startA_with_input_if_init_fail_expect_zero_range_id)
{
    MstxInit();
    if (GetInitResult() == 2) {  // only meaningful when init did not succeed
        return;
    }
    const std::string msg = "test";
    const auto handle = MstxDomainCreateATest("domain");
    const auto id = MstxDomainRangeStartATest(handle, msg.c_str(), nullptr);
    ASSERT_EQ(id, static_cast<mstxRangeId>(0));
    auto stateFunc = reinterpret_cast<GetRangeStateFunc>(dlsym(nullptr, "GetRangeState"));
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 0);
    }
    MstxDomainRangeEndTest(handle, id);
    MstxDomainDestroyTest(handle);
}
// With successful initialization, a domain range start returns TEST_RANGE_ID. If the
// optional GetRangeState probe is exported, it reports 1 while open and 0 after the end call.
TEST(CoreApi, test_domain_range_if_init_succ_expect_success)
{
    MstxInit();
    if (GetInitResult() != 2) {  // 2 == init succeeded
        return;
    }
    const std::string msg = "test";
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_NE(handle, nullptr);
    const auto id = MstxDomainRangeStartATest(handle, msg.c_str(), nullptr);
    ASSERT_EQ(id, static_cast<mstxRangeId>(TEST_RANGE_ID));
    auto stateFunc = reinterpret_cast<GetRangeStateFunc>(dlsym(nullptr, "GetRangeState"));
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 1);
    }
    MstxDomainRangeEndTest(handle, id);
    if (stateFunc != nullptr) {
        ASSERT_EQ(stateFunc(), 0);
    }
    MstxDomainDestroyTest(handle);
}
// MstxGetModuleFuncTableTest must reject an invalid module id and null output pointers.
// The outputs are value-initialized so no indeterminate values are ever passed around.
TEST(CoreApi, test_getFuncTable_with_invalid_input_expect_failed)
{
    mstxFuncTable outTable{};
    unsigned int outSize{0};
    // Invalid module id.
    ASSERT_EQ(MstxGetModuleFuncTableTest(MSTX_API_MODULE_INVALID, &outTable, &outSize), MSTX_FAIL);
    // Null table pointer.
    ASSERT_EQ(MstxGetModuleFuncTableTest(MSTX_API_MODULE_CORE, nullptr, &outSize), MSTX_FAIL);
    // Null size pointer, for both the core and core-domain modules.
    ASSERT_EQ(MstxGetModuleFuncTableTest(MSTX_API_MODULE_CORE, &outTable, nullptr), MSTX_FAIL);
    ASSERT_EQ(MstxGetModuleFuncTableTest(MSTX_API_MODULE_CORE_DOMAIN, &outTable, nullptr), MSTX_FAIL);
}
// Lazy initialization (triggered by RangeEnd) followed by an explicit MstxInit must
// leave the library in an initialized state.
TEST(CoreApi, test_double_init)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxRangeEndTest(0);
    ASSERT_NE(GetInitResult(), 0);
    MstxInit();
    ASSERT_NE(GetInitResult(), 0);
}
// RefreshUninitMstxContextFuncPtrTest is expected to leave CheckFuncPointerAllNull() == 1
// in three situations: after deinit with forceNull = 0, after deinit with forceNull = 1,
// and after init with forceNull = 1.
TEST(CoreApi, test_refresh_unit_api_to_null)
{
    MstxDeInit();
    RefreshUninitMstxContextFuncPtrTest(0);
    ASSERT_EQ(CheckFuncPointerAllNull(), 1);
    MstxDeInit();
    RefreshUninitMstxContextFuncPtrTest(1);
    ASSERT_EQ(CheckFuncPointerAllNull(), 1);
    MstxInit();
    RefreshUninitMstxContextFuncPtrTest(1);
    ASSERT_EQ(CheckFuncPointerAllNull(), 1);
}
// The first MarkA after deinit triggers initialization. A second MarkA issued after
// RefreshUninitMstxContextFuncPtrTest(1) must not crash (nothing is asserted on it).
TEST(CoreApi, test_init_with_mark_a)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMarkATest("test", nullptr);
    ASSERT_NE(GetInitResult(), 0);
    RefreshUninitMstxContextFuncPtrTest(1);
    MstxMarkATest("test", nullptr);
}
// The first MemHeapRegister call after deinit triggers initialization. The null
// descriptor is deliberate; only the init side effect is checked.
TEST(CoreApi, test_init_with_mem_heap_register_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMemHeapRegisterTest(globalDomain, nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first MemHeapUnregister call after deinit triggers initialization. The heap
// handle is value-initialized ({}); only the init side effect is checked.
TEST(CoreApi, test_init_with_mem_heap_unregisterr_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMemHeapUnregisterTest(globalDomain, {});
    ASSERT_NE(GetInitResult(), 0);
}
// The first MemRegionsRegister call after deinit triggers initialization.
TEST(CoreApi, test_init_with_mem_regions_registerr_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMemRegionsRegisterTest(globalDomain, nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first MemRegionsUnregister call after deinit triggers initialization.
TEST(CoreApi, test_init_with_mem_regions_unregisterr_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMemRegionsUnregisterTest(globalDomain, nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first MemPermissionsAssign call after deinit triggers initialization.
TEST(CoreApi, test_init_with_mem_permission_assign_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxMemPermissionsAssignTest(globalDomain, nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first DomainCreateA call after deinit triggers initialization. The returned
// handle (possibly null) is destroyed so the test does not leak a domain.
TEST(CoreApi, test_init_with_domain_createA_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    const auto handle = MstxDomainCreateATest("domain");
    ASSERT_NE(GetInitResult(), 0);
    MstxDomainDestroyTest(handle);
}
// The first DomainDestroy call after deinit (with a null handle) triggers initialization.
TEST(CoreApi, test_init_with_domain_destroy_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxDomainDestroyTest(nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first DomainMarkA call after deinit (with a null handle) triggers initialization.
TEST(CoreApi, test_init_with_domain_markA_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxDomainMarkATest(nullptr, "test", nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first DomainRangeStartA call after deinit (with a null handle) triggers initialization.
TEST(CoreApi, test_init_with_domain_range_startA_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxDomainRangeStartATest(nullptr, "test", nullptr);
    ASSERT_NE(GetInitResult(), 0);
}
// The first DomainRangeEnd call after deinit (with a null handle) triggers initialization.
TEST(CoreApi, test_init_with_domain_range_end_expect_init_success)
{
    MstxDeInit();
    ASSERT_EQ(GetInitResult(), 0);
    MstxDomainRangeEndTest(nullptr, 0);
    ASSERT_NE(GetInitResult(), 0);
}