* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <atomic>
#include <iostream>
#include <thread>
#include <unordered_set>
#include <vector>
#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include "hccl_mem_defs.h"
#define private public
#define protected public
#include "global_mem_record.h"
#include "global_mem_manager.h"
#undef protected
#undef private
using namespace std;
using namespace hccl;
class GlobalMemMgrMulThreadTest : public testing::Test
{
protected:
static void SetUpTestCase()
{
std::cout << "--GlobalMemMgrMulThreadTest SetUP--" << std::endl;
}
static void TearDownTestCase()
{
std::cout << "--GlobalMemMgrMulThreadTest TearDown--" << std::endl;
}
virtual void SetUp()
{
std::cout << "A Test SetUP" << std::endl;
}
virtual void TearDown()
{
std::cout << "A Test TearDown" << std::endl;
}
};
// Stub for hrtGetDeviceRefresh (installed via mockcpp `invoke`): unconditionally reports
// logical device 0 and success, so every calling thread sees the same device id.
// deviceLogicID must be non-null; it is written without a check.
HcclResult hrtGetDeviceRefreshStubSameId(s32* deviceLogicID)
{
    constexpr s32 kFixedDeviceId = 0;
    *deviceLogicID = kFixedDeviceId;
    return HCCL_SUCCESS;
}
// Concurrently calls GlobalMemRegMgr::GetInstance() from 16 threads while the device-id
// stub always reports device 0, and expects every thread to receive the same object.
TEST_F(GlobalMemMgrMulThreadTest, Ut_GlobalMemMgr_GetInstance_When_DeviceSameID_Expect_Success)
{
    MOCKER(hrtGetDeviceRefresh).stubs().will(invoke(hrtGetDeviceRefreshStubSameId));

    constexpr int kThreadCount = 16;
    std::vector<GlobalMemRegMgr*> instances(kThreadCount);
    std::vector<std::thread> workers;
    workers.reserve(kThreadCount);
    for (int slot = 0; slot < kThreadCount; ++slot) {
        // Each worker writes only its own slot, so `instances` needs no extra locking;
        // the joins below order those writes before the checks.
        workers.emplace_back([&instances, slot]() {
            instances[slot] = &GlobalMemRegMgr::GetInstance();
        });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    for (int i = 1; i < kThreadCount; i++) {
        EXPECT_EQ(instances[i], instances[0]);
    }
    GlobalMockObject::verify();
}
// Stub for hrtGetDeviceRefresh (installed via mockcpp `invoke`): each call reports the next
// logical device id from a process-wide atomic counter, wrapping at MAX_MODULE_DEVICE_NUM.
// The counter is a function-local static and is never reset between tests.
// deviceLogicID must be non-null; it is written without a check.
HcclResult hrtGetDeviceRefreshDifferentIdForTest(s32* deviceLogicID)
{
    static std::atomic<int> callCounter{0};
    const int callIndex = callCounter.fetch_add(1);
    *deviceLogicID = callIndex % MAX_MODULE_DEVICE_NUM;
    return HCCL_SUCCESS;
}
// Concurrently calls GlobalMemRegMgr::GetInstance() from 16 threads while the device-id stub
// hands out a different logical device id per call, and expects 16 distinct instances.
// NOTE(review): this relies on MAX_MODULE_DEVICE_NUM >= 16 and on GetInstance() invoking
// hrtGetDeviceRefresh exactly once per call; otherwise ids can wrap or collide — confirm.
TEST_F(GlobalMemMgrMulThreadTest, Ut_GlobalMemMgr_GetInstance_When_DifferentDeviceID_Expect_Success)
{
    MOCKER(hrtGetDeviceRefresh).stubs().will(invoke(hrtGetDeviceRefreshDifferentIdForTest));
    const int threadCount = 16;
    std::vector<std::thread> threads;
    threads.reserve(threadCount);
    std::vector<GlobalMemRegMgr*> instances(threadCount);
    for (int i = 0; i < threadCount; i++) {
        // Each thread writes only its own slot; join() below publishes the writes.
        threads.emplace_back([&instances, i]() {
            instances[i] = &GlobalMemRegMgr::GetInstance();
        });
    }
    for (auto& thread : threads) {
        thread.join();
    }
    const std::unordered_set<GlobalMemRegMgr*> uniqueInstances(instances.begin(), instances.end());
    // Compare in size_t: an int vs size_t comparison inside EXPECT_EQ triggers -Wsign-compare.
    EXPECT_EQ(uniqueInstances.size(), static_cast<std::size_t>(threadCount));
    GlobalMockObject::verify();
}
// Calls InitNic() concurrently from 16 threads on the device-0 GlobalMemRegMgr instance
// (HcclNetInit/HcclNetDeInit are stubbed to succeed) and expects every call to succeed.
// NOTE(review): InitNic() runs 16 times but DeInitNic() only once; if InitNic() is
// reference-counted, NIC state may outlive this test — confirm against GlobalMemRegMgr.
TEST_F(GlobalMemMgrMulThreadTest, Ut_GlobalMemMgr_InicNic_When_DeviceSameID_Expect_Success)
{
    MOCKER(hrtGetDeviceRefresh).stubs().will(invoke(hrtGetDeviceRefreshStubSameId));
    MOCKER(HcclNetInit).stubs().with(mockcpp::any()).will(returnValue(HCCL_SUCCESS));
    MOCKER(HcclNetDeInit).stubs().with(mockcpp::any()).will(returnValue(HCCL_SUCCESS));
    const int threadCount = 16;
    std::vector<std::thread> threads;
    threads.reserve(threadCount);
    std::vector<HcclResult> ret(threadCount);
    for (int i = 0; i < threadCount; i++) {
        // Each thread writes only its own result slot; join() below publishes the writes.
        threads.emplace_back([&ret, i]() {
            ret[i] = GlobalMemRegMgr::GetInstance().InitNic();
        });
    }
    for (auto& thread : threads) {
        thread.join();
    }
    // Start at 0: every thread's result must be checked, including the first one.
    for (int i = 0; i < threadCount; i++) {
        EXPECT_EQ(ret[i], HCCL_SUCCESS) << "InitNic failed on thread " << i;
    }
    GlobalMemRegMgr::GetInstance().DeInitNic();
    GlobalMockObject::verify();
}