* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include <cstdlib>
#include <mockcpp/mokc.h>
#include <mockcpp/mockcpp.hpp>
#define private public
#define protected public
#include "base_config.h"
#include "cfg_field.h"
#include "env_config.h"
#include "env_func.h"
#include "coll_service_default_impl.h"
#include "communicator_impl.h"
#include "virtual_topo.h"
#include "json_parser.h"
#include "internal_exception.h"
#include "dev_ub_connection.h"
#include "virtual_topo_stub.h"
#include "coll_alg_component_builder.h"
#include "local_rma_buffer.h"
#include "local_ub_rma_buffer.h"
#include "local_ipc_rma_buffer.h"
#include "not_support_exception.h"
#include "rdma_handle_manager.h"
#undef private
#undef protected
#include <hccl/hccl_types.h>
#include "coll_service_base.h"
#include "ccu_ctx_mgr.h"
#include "ccu_transport_manager.h"
#include "hccl.h"
using namespace Hccl;
using namespace std;
// Declaration only: the definition of this profiling-info stub is not part of this
// section. NOTE(review): none of the test cases visible here reference it — confirm
// it is still needed.
HcclResult GetProfilingInfoStub(s32 deviceLogicId,
                                CcuTaskArg &ccuTaskArg,
                                const uint64_t executorId,
                                std::vector<std::vector<CcuProfilingInfo>> &ccuProfilingInfo);
class CollServiceDefaultImplTest : public testing::Test {
protected:
static void SetUpTestCase()
{
std::cout << "CollServiceDefaultImpl SetUP" << std::endl;
}
static void TearDownTestCase()
{
std::cout << "CollServiceDefaultImpl TearDown" << std::endl;
}
virtual void SetUp()
{
std::cout << "A Test case in CollServiceDefaultImpl SetUP" << std::endl;
}
virtual void TearDown()
{
std::cout << "A Test case in CollServiceDefaultImpl TearDown" << std::endl;
GlobalMockObject::verify();
}
};
// Relative paths of the scratch files written and removed by the Gen*/Del* helpers.
// They resolve against the current working directory of the test process.
// NOTE(review): filePath and ranktable4pPath name the same file, so the two
// rank-table generators overwrite each other's output.
constexpr char filePath[] = "ranktable.json";
constexpr char ranktable4pPath[] = "ranktable.json";
constexpr char topoPath[] = "topo.json";
// Version "1.0" rank table: a single server ("server_id" "1") holding eight devices,
// with device_id N mapped to rank_id N for N = 0..7.
// In the code visible here it is only consumed by GenRankTableFile1Ser8Dev().
const std::string RankTable1Ser8Dev = R"(
{
"server_count":"1",
"server_list":
[
{
"device":[
{
"device_id":"0",
"rank_id":"0"
},
{
"device_id":"1",
"rank_id":"1"
},
{
"device_id":"2",
"rank_id":"2"
},
{
"device_id":"3",
"rank_id":"3"
},
{
"device_id":"4",
"rank_id":"4"
},
{
"device_id":"5",
"rank_id":"5"
},
{
"device_id":"6",
"rank_id":"6"
},
{
"device_id":"7",
"rank_id":"7"
}
],
"server_id":"1"
}
],
"status":"completed",
"version":"1.0"
}
)";
// Version "2.0" rank table with four ranks (rank_id == local_id == 0..3). Every rank
// has a single level-0 entry with id "az0-rack0", fabric_type "INNER" and no rank
// addresses. It is written to disk by GenRankTableFile4p() and is also passed
// directly as the rank-table argument of CommunicatorImpl::Init in
// test_base_register_offload_buf.
const std::string RankTable4p = R"(
{
"version": "2.0",
"rank_count" : "4",
"rank_list": [
{
"rank_id": 0,
"local_id": 0,
"level_list": [
{
"level": 0,
"id" : "az0-rack0",
"fabric_type": "INNER",
"rank_addr_type": "",
"rank_addrs": []
}
]
},
{
"rank_id": 1,
"local_id": 1,
"level_list": [
{
"level": 0,
"id" : "az0-rack0",
"fabric_type": "INNER",
"rank_addr_type": "",
"rank_addrs": []
}
]
},
{
"rank_id": 2,
"local_id": 2,
"level_list": [
{
"level": 0,
"id" : "az0-rack0",
"fabric_type": "INNER",
"rank_addr_type": "",
"rank_addrs": []
}
]
},
{
"rank_id": 3,
"local_id": 3,
"level_list": [
{
"level": 0,
"id" : "az0-rack0",
"fabric_type": "INNER",
"rank_addr_type": "",
"rank_addrs": []
}
]
}
]
}
)";
// Version "2.0" topology, hardware_type "950-2D-Fullmsh_64_plus_1": four peers
// (ids 0-3) and six level-0 "UB-CTP" edges, one for each unordered peer pair, so the
// peers form a full mesh. Each edge gives both endpoints a distinct device address.
// NOTE(review): despite the "1Ser8Dev" name this describes 4 peers, not 8 devices.
const std::string Topo1Ser8Dev = R"(
{
"version": "2.0",
"hardware_type" : "950-2D-Fullmsh_64_plus_1",
"peer_count" : 4,
"peer_list" :[
{ "id" : 0},
{ "id" : 1},
{ "id" : 2},
{ "id" : 3}
],
"edge_count" : 6,
"edge_list": [
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.7.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.7.101", "position" : "device"}
},
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.6.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.6.101", "position" : "device"}
},
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.5.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.5.101", "position" : "device"}
},
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.8.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.8.101", "position" : "device"}
},
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.9.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.9.101", "position" : "device"}
},
{
"level": 0,
"protocol": "UB-CTP",
"u_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.10.100", "position" : "device"},
"v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.10.101", "position" : "device"}
}
]
}
)";
/**
 * Parses RankTable1Ser8Dev and writes the re-serialized JSON to filePath.
 * All failures (parse error, or the file could not be opened/written/closed) are
 * reported on stdout and swallowed; the caller gets no status back.
 */
static void GenRankTableFile1Ser8Dev()
{
    try {
        nlohmann::json rankTableJson = nlohmann::json::parse(RankTable1Ser8Dev);
        std::ofstream out(filePath, std::ofstream::out);
        out << rankTableJson;
        // std::ofstream does not throw by default: open/write/close errors only set
        // failbit, so check the stream explicitly instead of reporting success.
        out.close();
        if (!out) {
            std::cout << filePath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << filePath << " generate failed!" << std::endl;
        return;
    }
    std::cout << filePath << " generated." << std::endl;
}
/**
 * Removes the file at filePath and reports the outcome on stdout.
 * A failed unlink() is only logged; nothing is thrown or returned.
 */
static void DelRankTableFile()
{
    const bool removed = unlink(filePath) != -1;
    if (removed) {
        std::cout << filePath << " deleted." << std::endl;
    } else {
        std::cout << filePath << " delete failed!" << std::endl;
    }
}
/**
 * Parses RankTable4p and writes the re-serialized JSON to ranktable4pPath.
 * All failures (parse error, or the file could not be opened/written/closed) are
 * reported on stdout and swallowed; the caller gets no status back.
 */
static void GenRankTableFile4p()
{
    try {
        nlohmann::json rankTableJson = nlohmann::json::parse(RankTable4p);
        std::ofstream out(ranktable4pPath, std::ofstream::out);
        out << rankTableJson;
        // std::ofstream does not throw by default: open/write/close errors only set
        // failbit, so check the stream explicitly instead of reporting success.
        out.close();
        if (!out) {
            std::cout << ranktable4pPath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << ranktable4pPath << " generate failed!" << std::endl;
        return;
    }
    std::cout << ranktable4pPath << " generated." << std::endl;
}
/**
 * Removes the file at ranktable4pPath and reports the outcome on stdout.
 * A failed unlink() is only logged; nothing is thrown or returned.
 */
static void DelRankTableFile4p()
{
    const bool removed = unlink(ranktable4pPath) != -1;
    if (removed) {
        std::cout << ranktable4pPath << " deleted." << std::endl;
    } else {
        std::cout << ranktable4pPath << " delete failed!" << std::endl;
    }
}
/**
 * Parses Topo1Ser8Dev and writes the re-serialized JSON to topoPath.
 * All failures (parse error, or the file could not be opened/written/closed) are
 * reported on stdout and swallowed; the caller gets no status back.
 */
static void GenTopoFile()
{
    try {
        nlohmann::json topoJson = nlohmann::json::parse(Topo1Ser8Dev);
        std::ofstream out(topoPath, std::ofstream::out);
        out << topoJson;
        // std::ofstream does not throw by default: open/write/close errors only set
        // failbit, so check the stream explicitly instead of reporting success.
        out.close();
        if (!out) {
            std::cout << topoPath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << topoPath << " generate failed!" << std::endl;
        return;
    }
    std::cout << topoPath << " generated." << std::endl;
}
/**
 * Removes the file at topoPath and reports the outcome on stdout.
 * A failed unlink() is only logged; nothing is thrown or returned.
 */
static void DelTopoFile()
{
    const bool removed = unlink(topoPath) != -1;
    if (removed) {
        std::cout << topoPath << " deleted." << std::endl;
    } else {
        std::cout << topoPath << " delete failed!" << std::endl;
    }
}
/**
 * RegisterOpBufToBufMgr must not throw when registering an op's input/output/scratch
 * buffers with a communicator initialised from RankTable4p (rank 1 of 4,
 * DEV_TYPE_950). Runtime calls (device queries, malloc/memset, IPC memory naming)
 * and LocalRmaBufManager::Reg are stubbed via mockcpp.
 */
TEST(CollServiceDefaultImplTest, test_base_register_offload_buf)
{
    // --- runtime / device stubs ---
    MOCKER(HrtGetDevice).stubs().will(returnValue(0));
    MOCKER(HrtGetDevicePhyIdByIndex).stubs().will(returnValue(static_cast<DevId>(1)));
    DevType mockedDevType = DevType::DEV_TYPE_950;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(mockedDevType));
    MOCKER(HrtIpcSetMemoryName).stubs();
    MOCKER(HrtDevMemAlignWithPage).stubs();
    MOCKER(HrtIpcDestroyMemoryName).stubs();
    void *fakeDevMem = nullptr;
    MOCKER(HrtMalloc).stubs().with(mockcpp::any(), mockcpp::any()).will(returnValue(fakeDevMem));
    MOCKER(HrtMemset).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());
    MOCKER_CPP(&CommunicatorImpl::InitNotifyFixedValue).stubs().will(ignoreReturnValue());

    // --- on-disk fixtures; deleted again right after the communicator is initialised ---
    GenRankTableFile4p();
    GenTopoFile();

    MOCKER_CPP(&CommunicatorImpl::InitCollService).stubs().will(returnValue(HcclResult::HCCL_SUCCESS));
    MOCKER(HrtSetDevice).stubs().with(mockcpp::any()).will(ignoreReturnValue());

    // --- IPC RMA buffer handed back by the stubbed LocalRmaBufManager::Reg ---
    std::shared_ptr<DevBuffer> rmaDevBuf = DevBuffer::Create(0x100, 0x100);
    LocalIpcRmaBuffer ipcRmaBuf(rmaDevBuf);
    MOCKER_CPP(&HcclOneSidedService::AddOpCounterMems).stubs().will(ignoreReturnValue());
    using RegFn = LocalRmaBuffer *(LocalRmaBufManager::*)(const string &, BufferType, std::shared_ptr<Buffer>,
                                                          const PortData &, LinkProtocol);
    MOCKER_CPP(&LocalRmaBufManager::Reg, RegFn)
        .stubs()
        .will(returnValue(static_cast<LocalRmaBuffer *>(&ipcRmaBuf)));

    // --- communicator under test ---
    CommunicatorImpl comm;
    CommParams commParams;
    commParams.commId = "commId";
    commParams.myRank = 1;
    commParams.rankSize = 4;
    commParams.devType = DevType::DEV_TYPE_950;
    HcclCommConfig commConfig;
    HcclCommConfigInit(&commConfig);
    comm.rankGraph = make_unique<RankGraph>(0);
    comm.rankGraph->peers_[0] = make_shared<NetInstance::Peer>(0, 0, 0, 0);
    comm.Init(commParams, RankTable4p, commConfig);
    comm.InitDataBufferManager();
    DelRankTableFile4p();
    DelTopoFile();

    // --- op whose buffers get registered ---
    CollServiceDefaultImpl service(&comm);
    CollOperator op;
    op.inputMem = DevBuffer::Create(0x100, 1);
    op.outputMem = DevBuffer::Create(0x100, 1);
    op.scratchMem = DevBuffer::Create(0x100, 1);
    op.opTag = "test_opTag";
    EXPECT_NO_THROW(service.RegisterOpBufToBufMgr(op));
}
/**
 * RegisterOffloadMasterStream must not throw when given op tag "test" and a
 * default-constructed Stream. Stream create/destroy and the stream-id query are
 * stubbed via mockcpp.
 */
TEST(CollServiceDefaultImplTest, test_base_register_offload_stream)
{
    MOCKER(HrtStreamDestroy).stubs();
    // Non-null fake handle returned by the stubbed stream creation.
    void *fakeStream = reinterpret_cast<void *>(0x1);
    MOCKER(HrtStreamCreateWithFlags).stubs().will(returnValue(fakeStream));
    MOCKER(HrtGetStreamId).stubs().will(returnValue(0));
    CommunicatorImpl comm;
    comm.id = "commId";
    comm.streamManager = make_unique<StreamManager>(&comm);
    CollServiceDefaultImpl service(&comm);
    std::string opTag = "test";
    // make_unique already yields an rvalue; wrapping it in std::move is redundant.
    EXPECT_NO_THROW(service.RegisterOffloadMasterStream(opTag, make_unique<Stream>()));
}
/**
 * CalcCollOffloadOpRes for an INT8 ALLREDUCE with dataSize 100 must not throw. It
 * must also overwrite the pre-set request values (2 sub-queues, 0 scratch bytes)
 * with 16 sub-queues and 256 * 1024 * 1024 scratch bytes. The CollAlgComponent is
 * built over a VirtualTopoStub (rank 0 of 4, DEV_TYPE_950) and accState is
 * HOSTCPU_TS. PRIM_QUEUE_GEN_NAME is "CcuAllReduceMesh1D" only while this test runs.
 */
TEST(CollServiceDefaultImplTest, test_calc_coll_offload_op_res_with_hccl_success_returned)
{
    // Sets PRIM_QUEUE_GEN_NAME and restores its previous state (value or absence) on
    // scope exit so the override does not leak into later tests in the same process.
    // Declared first so it is destroyed last, after every object below.
    struct PrimQueueEnvGuard {
        bool hadValue = false;
        std::string oldValue;
        PrimQueueEnvGuard()
        {
            const char *prev = getenv("PRIM_QUEUE_GEN_NAME");
            if (prev != nullptr) {
                hadValue = true;
                oldValue = prev;
            }
            setenv("PRIM_QUEUE_GEN_NAME", "CcuAllReduceMesh1D", 1);
        }
        ~PrimQueueEnvGuard()
        {
            if (hadValue) {
                setenv("PRIM_QUEUE_GEN_NAME", oldValue.c_str(), 1);
            } else {
                unsetenv("PRIM_QUEUE_GEN_NAME");
            }
        }
    } primQueueEnv;

    CommunicatorImpl comm;
    comm.id = "commId";
    comm.streamManager = make_unique<StreamManager>(&comm);
    CollServiceDefaultImpl service(&comm);
    OpType opType = OpType::ALLREDUCE;
    u64 dataSize = 100;
    CollOffloadOpResReq resReq1;
    resReq1.requiredSubQueNum = 2;
    resReq1.requiredScratchMemSize = 0;
    VirtualTopoStub virtTopo(0);
    string rankTable = "test";
    virtTopo.TopoInit91095OneTimesFour(rankTable);
    RankId myRank = 0;
    u32 rankSize = 4;
    CollAlgComponentBuilder collAlgComponentBuilder;
    std::shared_ptr<CollAlgComponent> collAlgComponent = collAlgComponentBuilder.SetRankGraph(&virtTopo)
        .SetDevType(DevType::DEV_TYPE_950)
        .SetMyRank(myRank)
        .SetRankSize(rankSize)
        .Build();
    comm.collAlgComponent = collAlgComponent;
    OpExecuteConfig opConfig;
    opConfig.accState = AcceleratorState::HOSTCPU_TS;
    comm.opExecuteConfig = opConfig;
    EXPECT_NO_THROW(comm.CalcCollOffloadOpRes(opType, dataSize, HCCL_DATA_TYPE_INT8, resReq1));
    EXPECT_EQ(16, resReq1.requiredSubQueNum);
    EXPECT_EQ(256*1024*1024, resReq1.requiredScratchMemSize);
}
/**
 * CollAlgComponent test double. Both Orchestrate overloads (InsQuePtr and
 * PrimQuePtr) do nothing and always return HCCL_SUCCESS.
 * The base is constructed with (nullptr, DEV_TYPE_950, 0, 1). NOTE(review):
 * presumably rank graph, device type, own rank and rank size — confirm against
 * CollAlgComponent's constructor.
 */
class FakeCollAlgComponent : public CollAlgComponent {
public:
    FakeCollAlgComponent() : CollAlgComponent(nullptr, DevType::DEV_TYPE_950, 0, 1)
    {}

    // Parameters are intentionally unnamed: the fake ignores all of them.
    HcclResult Orchestrate(const CollAlgOperator &, const CollAlgParams &, const string &, InsQuePtr) override
    {
        return HcclResult::HCCL_SUCCESS;
    }

    HcclResult Orchestrate(const CollAlgOperator &, const CollAlgParams &, const string &, PrimQuePtr) override
    {
        return HcclResult::HCCL_SUCCESS;
    }
};
class FakeCollAlgComponentWithError : public CollAlgComponent {
public:
FakeCollAlgComponentWithError() : CollAlgComponent(nullptr, DevType::DEV_TYPE_950, 0, 1)
{}
HcclResult Orchestrate(
const CollAlgOperator &op, const CollAlgParams ¶ms, const string &algName, InsQuePtr queue) override
{
return HcclResult::HCCL_E_INTERNAL;
}
};
/**
 * OrchestrateWithIns must not throw for an OPBASE INT8 SUM ALLREDUCE of 4 elements
 * when the communicator's CollAlgComponent is a FakeCollAlgComponent. Its InsQuePtr
 * Orchestrate is additionally stubbed to return HCCL_SUCCESS.
 */
TEST(CollServiceDefaultImplTest, test_orchestrate_with_ins)
{
    CommunicatorImpl comm;
    comm.id = "test";
    CollServiceDefaultImpl service(&comm);
    CollOperator op;
    op.opMode = OpMode::OPBASE;
    op.opType = OpType::ALLREDUCE;
    op.reduceOp = ReduceOp::SUM;
    op.dataType = DataType::INT8;
    op.dataCount = 4;
    op.root = 0;
    op.scratchMem = DevBuffer::Create(0x100, 1);
    std::shared_ptr<FakeCollAlgComponent> collAlgComponent = std::make_shared<FakeCollAlgComponent>();
    comm.collAlgComponent = collAlgComponent;
    // Member-pointer type selects the InsQuePtr overload; parameter names are irrelevant here.
    MOCKER_CPP_VIRTUAL(*collAlgComponent,
        &CollAlgComponent::Orchestrate,
        HcclResult(CollAlgComponent::*)(const CollAlgOperator &, const CollAlgParams &, const string &, InsQuePtr))
        .stubs()
        .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
        .will(returnValue(HcclResult::HCCL_SUCCESS));
    EXPECT_NO_THROW(service.OrchestrateWithIns(op));
}
/**
 * When the CollAlgComponent's InsQuePtr Orchestrate reports HCCL_E_INTERNAL
 * (FakeCollAlgComponentWithError), OrchestrateWithIns must throw InternalException.
 */
TEST(CollServiceDefaultImplTest, test_orchestrate_with_ins_throw_exception)
{
    CommunicatorImpl comm;
    comm.id = "test";
    CollServiceDefaultImpl service(&comm);

    // 4-element INT8 SUM ALLREDUCE in OPBASE mode, rooted at rank 0.
    CollOperator allReduceOp;
    allReduceOp.opMode = OpMode::OPBASE;
    allReduceOp.opType = OpType::ALLREDUCE;
    allReduceOp.reduceOp = ReduceOp::SUM;
    allReduceOp.dataType = DataType::INT8;
    allReduceOp.dataCount = 4;
    allReduceOp.root = 0;
    allReduceOp.scratchMem = DevBuffer::Create(0x100, 1);

    // Orchestration through this fake always fails.
    comm.collAlgComponent = std::make_shared<FakeCollAlgComponentWithError>();
    EXPECT_THROW(service.OrchestrateWithIns(allReduceOp), InternalException);
}
/**
 * AllocQNotifyForSingleQ on a queue containing one InsLocalWaitGroup must not throw.
 * QueueWaitGroupCntNotifyManager::ApplyFor is stubbed out.
 */
TEST(CollServiceDefaultImplTest, alloc_queue_notify_for_single_queue)
{
    CommunicatorImpl comm;
    auto queueNotifyManager = std::make_unique<QueueNotifyManager>(comm);
    auto queueWaitGroupCntNotifyManager = std::make_unique<QueueWaitGroupCntNotifyManager>();
    CollServiceDefaultImpl service(&comm);
    comm.queueNotifyManager = std::move(queueNotifyManager);
    comm.queueWaitGroupCntNotifyManager = std::move(queueWaitGroupCntNotifyManager);
    MOCKER_CPP(&QueueWaitGroupCntNotifyManager::ApplyFor).stubs().with();
    InsQueue insQueue;
    auto insLocalWaitGroup = std::make_unique<InsLocalWaitGroup>(0);
    insQueue.Append(std::move(insLocalWaitGroup));
    // Explicit expectation: an exception is reported as a failed check and the mock
    // reset below still runs.
    EXPECT_NO_THROW(service.AllocQNotifyForSingleQ(insQueue));
    // NOTE(review): this reset runs before `service` and `comm` are destroyed, so
    // their destructors execute with all mocks removed.
    GlobalMockObject::reset();
}
/**
 * AllocQNotifyForSingleQ on a queue containing one InsLocalBcastPost must not throw.
 * QueueBcastPostCntNotifyManager::ApplyFor is stubbed out.
 */
TEST(CollServiceDefaultImplTest, alloc_cnt_notify_for_single_queue)
{
    CommunicatorImpl comm;
    auto queueNotifyManager = std::make_unique<QueueNotifyManager>(comm);
    auto queueBcastPostCntNotifyManager = std::make_unique<QueueBcastPostCntNotifyManager>();
    CollServiceDefaultImpl service(&comm);
    comm.queueNotifyManager = std::move(queueNotifyManager);
    comm.queueBcastPostCntNotifyManager = std::move(queueBcastPostCntNotifyManager);
    MOCKER_CPP(&QueueBcastPostCntNotifyManager::ApplyFor).stubs().with();
    InsQueue insQueue;
    auto insLocalBcastPost = std::make_unique<InsLocalBcastPost>(0);
    insQueue.Append(std::move(insLocalBcastPost));
    // Explicit expectation: an exception is reported as a failed check and the mock
    // reset below still runs.
    EXPECT_NO_THROW(service.AllocQNotifyForSingleQ(insQueue));
    // NOTE(review): this reset runs before `service` and `comm` are destroyed, so
    // their destructors execute with all mocks removed.
    GlobalMockObject::reset();
}
/**
 * GetHostDeviceSyncNotifyManager() must return the manager object owned by the
 * communicator, and GetIdIndex() must be 0 in this setup. Device and notify runtime
 * calls are stubbed.
 */
TEST(CollServiceDefaultImplTest, test_communicator_impl_hostDeviceSyncNotifyManager_GetIdIndex)
{
    // Non-null fake notify handle returned by both notify-creation stubs.
    void *fakeNotify = reinterpret_cast<void *>(1);
    MOCKER(HrtGetDevice).stubs().will(returnValue(0));
    MOCKER(HrtNotifyCreate).stubs().will(returnValue(fakeNotify));
    MOCKER(HrtNotifyCreateWithFlag).stubs().will(returnValue(fakeNotify));
    MOCKER(HrtGetNotifyID).stubs().will(returnValue(1));
    MOCKER(HrtGetDevicePhyIdByIndex).stubs().will(returnValue(static_cast<DevId>(1)));
    CommunicatorImpl comm;
    comm.hostDeviceSyncNotifyManager = std::make_unique<HostDeviceSyncNotifyManager>();
    EXPECT_EQ(comm.hostDeviceSyncNotifyManager.get(), &(comm.GetHostDeviceSyncNotifyManager()));
    EXPECT_EQ(0, comm.GetIdIndex());
}
/**
 * LoadWithOpBasedMode must not throw for an OPBASE op on a communicator whose
 * collaborators (buffer/stream registration, socket creation, link discovery, prim
 * translation, connection building, DFX mirroring and EnvConfig::GetAlgoConfig) are
 * stubbed. Orchestration goes through a FakeCollAlgComponent.
 */
TEST(CollServiceDefaultImplTest, coll_service_default_impl_orchestrate_with_ins_success)
{
    CommunicatorImpl comm;
    comm.id = "test";
    comm.rmaConnectionManager = std::make_unique<RmaConnManager>(comm);
    comm.connLocalNotifyManager = std::make_unique<ConnLocalNotifyManager>(&comm);
    comm.connLocalCntNotifyManager = std::make_unique<ConnLocalCntNotifyManager>(&comm);
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    comm.socketManager = std::make_unique<SocketManager>(comm, 1, 1, 1);
    comm.trace = std::make_unique<Trace>();
    comm.memTransportManager = make_unique<MemTransportManager>(comm);
    comm.cclBuffer = DevBuffer::Create(0x100, 200);

    // SEND of 10 INT8 elements with null user buffers.
    // NOTE(review): the destination rank is 0; if a remote peer was intended, revisit.
    CollOpParams collOpParams;
    collOpParams.opType = OpType::SEND;
    collOpParams.dataType = DataType::INT8;
    collOpParams.reduceOp = ReduceOp::SUM;
    collOpParams.sendBuf = nullptr;
    collOpParams.recvBuf = nullptr;
    collOpParams.count = 10;
    collOpParams.root = 0;
    collOpParams.staticAddr = true;
    collOpParams.staticShape = true;
    collOpParams.outputDataType = DataType::INT8;
    collOpParams.debugCase = 1;
    collOpParams.dstRank = 0;
    std::string name = "test";
    comm.CovertToCurrentCollOperator(name, collOpParams, OpMode::OPBASE);
    CollServiceDefaultImpl service(&comm);
    MOCKER_CPP(&Interpreter::Submit).stubs();
    std::shared_ptr<FakeCollAlgComponent> collAlgComponent = std::make_shared<FakeCollAlgComponent>();
    comm.collAlgComponent = collAlgComponent;
    // Member-pointer type selects the InsQuePtr overload; parameter names are irrelevant here.
    MOCKER_CPP_VIRTUAL(*collAlgComponent,
        &CollAlgComponent::Orchestrate,
        HcclResult(CollAlgComponent::*)(const CollAlgOperator &, const CollAlgParams &, const string &, InsQuePtr))
        .stubs()
        .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
        .will(returnValue(HcclResult::HCCL_SUCCESS));
    service.connectionsBuilders[comm.id] = std::make_unique<ConnectionsBuilder>(comm);
    DevType devType = DevType::DEV_TYPE_910A;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    MOCKER_CPP(&CollServiceDefaultImpl::RegisterOpBufToBufMgr).stubs();
    MOCKER_CPP(&CollServiceDefaultImpl::RegisterOpbasedStream).stubs();
    MOCKER_CPP(&SocketManager::BatchCreateSockets).stubs();
    MOCKER_CPP(&CollServiceBase::SaveMirrorDfxOpInfo).stubs();
    vector<LinkData> links;
    MOCKER_CPP(&InsQueue::GetUniqueLinks).stubs().will(returnValue(links));
    shared_ptr<InsQueue> insQueue = make_shared<InsQueue>();
    MOCKER_CPP(&PrimTranslator::Translate).stubs().will(returnValue(insQueue));
    MOCKER_CPP(&ConnectionsBuilder::BatchBuild).stubs();
    CollOperator op;
    op.opMode = OpMode::OPBASE;
    op.inputMem = DevBuffer::Create(0x100, 1);
    op.outputMem = DevBuffer::Create(0x100, 1);
    op.scratchMem = DevBuffer::Create(0x100, 1);
    MOCKER(HrtGetStreamId).stubs().with(mockcpp::any()).will(returnValue(0));
    auto stream = std::make_unique<Stream>(nullptr);
    comm.streamManager->opbase->master = make_unique<Stream>(&comm);
    comm.currentCollOperator->opMode = OpMode::OPBASE;
    comm.rankGraph = std::make_unique<RankGraph>(comm.GetMyRank());
    comm.rankGraph->peers_[comm.GetMyRank()] = std::make_shared<NetInstance::Peer>(comm.GetMyRank(), 0, 0, 0);
    // HCCL_BUFFSIZE field with default 200 * 1024 * 1024, marked as already parsed;
    // the lambda multiplies a parsed value by 1024 * 1024.
    EnvAlgoConfig algoCfg;
    algoCfg.bufferSize =
        CfgField<u64>("HCCL_BUFFSIZE", 200 * 1024 * 1024, Str2T<u64>, CHK_RANGE_CLOSED<u64>(1, ULLONG_MAX), [](u64 &i) {
            i *= 1024 * 1024;
        });
    algoCfg.bufferSize.isParsed = true;
    MOCKER_CPP(&EnvConfig::GetAlgoConfig).stubs().will(returnValue(algoCfg));
    MOCKER_CPP(&CollServiceBase::SaveMirrorDfxOpInfo).stubs().with(mockcpp::any()).will(ignoreReturnValue());
    EXPECT_NO_THROW(service.LoadWithOpBasedMode(op, std::move(stream)));
}
/**
 * CovertToCurrentCollOperator must not throw when converting a 10-element INT8 SEND
 * (dstRank 1, send/recv buffers pointing at a local u32) in OPBASE mode.
 * NOTE(review): despite its name, this test neither touches RmaConnManager nor
 * checks any DTO list — confirm the intended coverage.
 */
TEST(CollServiceDefaultImplTest, coll_service_default_RmaConnManager_Get_allDTO_return_empty)
{
    CommunicatorImpl comm;
    CollOpParams collOpParams;
    collOpParams.opType = OpType::SEND;
    collOpParams.dataType = DataType::INT8;
    collOpParams.reduceOp = ReduceOp::SUM;
    collOpParams.dstRank = 1;
    u32 buffer = 10;
    collOpParams.sendBuf = static_cast<void *>(&buffer);
    collOpParams.recvBuf = static_cast<void *>(&buffer);
    collOpParams.count = 10;
    collOpParams.root = 0;
    collOpParams.staticAddr = true;
    collOpParams.staticShape = true;
    comm.cclBuffer = make_shared<DevBuffer>(10);
    string tag = "optag";
    EXPECT_NO_THROW(comm.CovertToCurrentCollOperator(tag, collOpParams, OpMode::OPBASE));
}
/**
 * UpdateUbCiIfNeed must not throw in two situations: first with no updating event
 * set, then with an event whose stubbed aclrtQueryEventWaitStatus reports
 * ACL_EVENT_WAIT_STATUS_COMPLETE. IfNeedUpdatingUbCi is stubbed to true, and
 * UbCiUpdaterManager::SaveConnsCi / UpdateConnsCi are stubbed out.
 */
TEST(CollServiceDefaultImplTest, col_service_default_impl_update_ub_ci_if_need_success)
{
    DevType devType = DevType::DEV_TYPE_910A;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    CommunicatorImpl comm;
    // SEND of 10 INT8 elements with null user buffers.
    // NOTE(review): the destination rank is 0; if a remote peer was intended, revisit.
    CollOpParams collOpParams;
    collOpParams.opType = OpType::SEND;
    collOpParams.dataType = DataType::INT8;
    collOpParams.reduceOp = ReduceOp::SUM;
    collOpParams.sendBuf = nullptr;
    collOpParams.recvBuf = nullptr;
    collOpParams.count = 10;
    collOpParams.root = 0;
    collOpParams.staticAddr = true;
    collOpParams.staticShape = true;
    collOpParams.outputDataType = DataType::INT8;
    collOpParams.debugCase = 1;
    collOpParams.dstRank = 0;
    std::string name = "test";
    comm.cclBuffer = DevBuffer::Create(0x100, 200);
    comm.CovertToCurrentCollOperator(name, collOpParams, OpMode::OPBASE);
    comm.id = "test";
    comm.rmaConnectionManager = std::make_unique<RmaConnManager>(comm);
    comm.connLocalNotifyManager = std::make_unique<ConnLocalNotifyManager>(&comm);
    comm.connLocalCntNotifyManager = std::make_unique<ConnLocalCntNotifyManager>(&comm);
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    // The handle must be initialised: reading an uninitialised pointer is undefined behaviour.
    void *masterStreamHandle = nullptr;
    unique_ptr<Stream> master = make_unique<Stream>(masterStreamHandle);
    comm.streamManager->opbase->master = std::move(master);
    comm.socketManager = std::make_unique<SocketManager>(comm, 1, 1, 1);
    comm.memTransportManager = make_unique<MemTransportManager>(comm);
    CollServiceDefaultImpl service(&comm);

    // First call: no updating event is set.
    service.updatingUbCiEvent = nullptr;
    MOCKER(IfNeedUpdatingUbCi).stubs().will(returnValue(true));
    MOCKER_CPP(&UbCiUpdaterManager::SaveConnsCi).stubs().with(mockcpp::any()).will(ignoreReturnValue());
    CollOperator op;
    op.opTag = "test";
    EXPECT_NO_THROW(service.UpdateUbCiIfNeed(op.opTag));

    // Second call: an event is set and its wait status is stubbed as complete.
    service.updatingUbCiEvent = make_unique<MaskEvent>();
    aclrtEventWaitStatus status = ACL_EVENT_WAIT_STATUS_COMPLETE;
    MOCKER(aclrtQueryEventWaitStatus).stubs().with(mockcpp::any(), outBoundP(&status, sizeof(status))).will(returnValue(ACL_SUCCESS));
    MOCKER_CPP(&UbCiUpdaterManager::UpdateConnsCi).stubs().with(mockcpp::any());
    EXPECT_NO_THROW(service.UpdateUbCiIfNeed(op.opTag));
}
/**
 * AddCountTask(true) must not throw on an OPBASE communicator that has a master
 * stream. Stream, memcpy/memset/malloc, reduce and device-type runtime calls are
 * stubbed.
 */
TEST(CollServiceDefaultImplTest, AddCountTask)
{
    // Clear any mocks left installed by earlier TEST() cases before stubbing anew.
    GlobalMockObject::verify();

    void *nullStream = nullptr;
    void *nullDevMem = nullptr;
    DevType mockedDevType = DevType::DEV_TYPE_910A;

    // Stream runtime.
    MOCKER(HrtStreamDestroy).stubs().with(mockcpp::any());
    MOCKER(HrtStreamCreateWithFlags).stubs().with(mockcpp::any(), mockcpp::any()).will(returnValue(nullStream));
    MOCKER(HrtGetStreamId).stubs().with(mockcpp::any()).will(returnValue(0));
    // Device query and data-movement runtime.
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(mockedDevType));
    MOCKER(HrtReduceAsync).stubs().with(mockcpp::any());
    MOCKER(HrtMemcpy).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());
    MOCKER(HrtMalloc).stubs().with(mockcpp::any(), mockcpp::any()).will(returnValue(nullDevMem));
    MOCKER(HrtMemset).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());

    CommunicatorImpl comm;
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    comm.streamManager->opbase->master = make_unique<Stream>(&comm);
    comm.currentCollOperator = make_unique<CollOperator>();
    comm.currentCollOperator->opMode = OpMode::OPBASE;

    CollServiceDefaultImpl collServiceDefaultImpl(&comm);
    EXPECT_NO_THROW(collServiceDefaultImpl.AddCountTask(true));
}
/**
 * RecoverTransport on a DEV_TYPE_950 device must throw NotSupportException. The call
 * is given one P2P UB_CTP link and a single LinkGroup built from two AddLink calls
 * (paired with 0). RdmaHandleManager::GetDieAndFuncId is stubbed to return (0, 0).
 */
TEST(CollServiceDefaultImplTest, Test_RecoverTransport)
{
    std::unique_ptr<CommunicatorImpl> comm = std::make_unique<CommunicatorImpl>();
    CollServiceDefaultImpl collServiceDefaultImpl(comm.get());
    // Construct the pair type directly instead of make_pair with explicit template arguments.
    MOCKER_CPP(&RdmaHandleManager::GetDieAndFuncId).stubs().will(returnValue(std::pair<uint32_t, uint32_t>(0, 0)));
    DevType devType = DevType::DEV_TYPE_950;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    vector<LinkData> links;
    vector<std::pair<LinkGroup, u32>> linkGroupPair;
    LinkData linkData(PortDeploymentType::P2P, LinkProtocol::UB_CTP, 0, 1, IpAddress{"10.0.0.1"}, IpAddress{"10.0.0.2"});
    links.push_back(linkData);
    LinkGroup linkGroup{};
    linkGroup.AddLink({linkData});
    LinkData otherLinkData(PortDeploymentType::P2P, LinkProtocol::UB_CTP, 1, 1, IpAddress{"10.0.0.3"}, IpAddress{"10.0.0.4"});
    linkGroup.AddLink({otherLinkData});
    linkGroupPair.push_back(make_pair(linkGroup, 0));
    EXPECT_THROW(collServiceDefaultImpl.RecoverTransport(links, linkGroupPair), NotSupportException);
}