#include <cstdlib>
#include <string>
#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include "hccl_api_base_test.h"
#include "../../stub/llt_hccl_stub_rank_graph.h"
using namespace hccl;
class HcclCommHostTest : public testing::Test {
protected:
void SetUp() override {}
void TearDown() override { GlobalMockObject::verify(); }
};
// Resume() on a communicator initialised through InitCollComm (device type mocked to
// DEV_TYPE_950, IsSupportHCCLV2 mocked to true, HCCL_INDEPENDENT_OP=1) is expected to return
// HCCL_SUCCESS when the mocked CollComm::Resume succeeds.
TEST_F(HcclCommHostTest, Ut_ResumeWhenIsCommunicatorV2ExpectSuccess)
{
    MOCKER(hrtGetDeviceType)
        .stubs()
        .with(outBound(DevType::DEV_TYPE_950))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(IsSupportHCCLV2)
        .stubs()
        .will(returnValue(true));
    setenv("HCCL_INDEPENDENT_OP", "1", 1);
    void* commV2 = reinterpret_cast<void*>(0x2000);  // dummy non-null handle value
    RankGraphStub rankGraphStub;
    std::shared_ptr<Hccl::RankGraph> rankGraphV2 = rankGraphStub.Create2PGraph();
    u32 rank = 1;
    HcclMem cclBuffer{};
    cclBuffer.size = 1;
    cclBuffer.type = HcclMemType::HCCL_MEM_TYPE_HOST;
    cclBuffer.addr = reinterpret_cast<void*>(0x1000);  // dummy non-null address value
    char commName[ROOTINFO_INDENTIFIER_MAX_LENGTH] = {};
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, commName);
    MOCKER_CPP(&CollComm::Init)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::GetHDCommunicate)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::Resume)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    HcclCommConfig config{};
    unsetenv("HCCL_DFS_CONFIG");
    HcclResult ret = hcclCommPtr->InitCollComm(commV2, rankGraphV2.get(), rank, cclBuffer, commName, &config);
    // Precondition: if initialisation fails, calling Resume() below is meaningless, so stop here.
    ASSERT_EQ(ret, HCCL_SUCCESS);
    hcclCommPtr->devType_ = DevType::DEV_TYPE_950;
    ret = hcclCommPtr->Resume();
    EXPECT_EQ(ret, HCCL_SUCCESS);
}
// Resume() on a communicator initialised through InitCollComm (device type mocked to
// DEV_TYPE_950, IsSupportHCCLV2 mocked to true, HCCL_INDEPENDENT_OP=1) is expected to return
// the error produced by CollComm::Resume, which is mocked here to return HCCL_E_INTERNAL.
TEST_F(HcclCommHostTest, Ut_ResumeWhenIsCommunicatorV2AndCollResumeFailsExpectError)
{
    MOCKER(hrtGetDeviceType)
        .stubs()
        .with(outBound(DevType::DEV_TYPE_950))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(IsSupportHCCLV2)
        .stubs()
        .will(returnValue(true));
    setenv("HCCL_INDEPENDENT_OP", "1", 1);
    void* commV2 = reinterpret_cast<void*>(0x2000);  // dummy non-null handle value
    RankGraphStub rankGraphStub;
    std::shared_ptr<Hccl::RankGraph> rankGraphV2 = rankGraphStub.Create2PGraph();
    u32 rank = 1;
    HcclMem cclBuffer{};
    cclBuffer.size = 1;
    cclBuffer.type = HcclMemType::HCCL_MEM_TYPE_HOST;
    cclBuffer.addr = reinterpret_cast<void*>(0x1000);  // dummy non-null address value
    char commName[ROOTINFO_INDENTIFIER_MAX_LENGTH] = {};
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, commName);
    MOCKER_CPP(&CollComm::Init)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::GetHDCommunicate)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::Resume)
        .stubs()
        .will(returnValue(HCCL_E_INTERNAL));
    HcclCommConfig config{};
    unsetenv("HCCL_DFS_CONFIG");
    HcclResult ret = hcclCommPtr->InitCollComm(commV2, rankGraphV2.get(), rank, cclBuffer, commName, &config);
    // Precondition: if initialisation fails, the error checked below would not come from Resume.
    ASSERT_EQ(ret, HCCL_SUCCESS);
    hcclCommPtr->devType_ = DevType::DEV_TYPE_950;
    ret = hcclCommPtr->Resume();
    EXPECT_EQ(ret, HCCL_E_INTERNAL);
}
// GetCommStatus() on a communicator initialised through InitCollComm (device type mocked to
// DEV_TYPE_950, IsSupportHCCLV2 mocked to true, HCCL_INDEPENDENT_OP=1) is expected to succeed
// and report the status returned by CollComm::GetCommStatus (mocked to SUSPENDING).
TEST_F(HcclCommHostTest, Ut_GetCommStatusWhenIsCommunicatorV2ExpectCollStatus)
{
    MOCKER(hrtGetDeviceType)
        .stubs()
        .with(outBound(DevType::DEV_TYPE_950))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(IsSupportHCCLV2)
        .stubs()
        .will(returnValue(true));
    setenv("HCCL_INDEPENDENT_OP", "1", 1);
    void* commV2 = reinterpret_cast<void*>(0x2000);  // dummy non-null handle value
    RankGraphStub rankGraphStub;
    std::shared_ptr<Hccl::RankGraph> rankGraphV2 = rankGraphStub.Create2PGraph();
    u32 rank = 1;
    HcclMem cclBuffer{};
    cclBuffer.size = 1;
    cclBuffer.type = HcclMemType::HCCL_MEM_TYPE_HOST;
    cclBuffer.addr = reinterpret_cast<void*>(0x1000);  // dummy non-null address value
    char commName[ROOTINFO_INDENTIFIER_MAX_LENGTH] = {};
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, commName);
    MOCKER_CPP(&CollComm::Init)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::GetHDCommunicate)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::GetCommStatus)
        .stubs()
        .will(returnValue(HcclCommStatus::HCCL_COMM_STATUS_SUSPENDING));
    HcclCommConfig config{};
    unsetenv("HCCL_DFS_CONFIG");
    HcclResult ret = hcclCommPtr->InitCollComm(commV2, rankGraphV2.get(), rank, cclBuffer, commName, &config);
    // Precondition: if initialisation fails, querying the status below is meaningless.
    ASSERT_EQ(ret, HCCL_SUCCESS);
    hcclCommPtr->devType_ = DevType::DEV_TYPE_950;
    HcclCommStatus status = HcclCommStatus::HCCL_COMM_STATUS_INVALID;
    ret = hcclCommPtr->GetCommStatus(status);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(status, HcclCommStatus::HCCL_COMM_STATUS_SUSPENDING);
}
// A default-constructed hcclComm whose device type is DEV_TYPE_910_93 is expected to reject
// GetCommStatus() with HCCL_E_NOT_SUPPORT.
TEST_F(HcclCommHostTest, Ut_GetCommStatusWhenIsCommunicatorV1ExpectReturnE_NOT_SUPPORT)
{
    auto comm = std::make_shared<hccl::hcclComm>();
    comm->devType_ = DevType::DEV_TYPE_910_93;

    HcclCommStatus queried{HcclCommStatus::HCCL_COMM_STATUS_INVALID};
    const HcclResult result = comm->GetCommStatus(queried);

    EXPECT_EQ(result, HCCL_E_NOT_SUPPORT);
}
// With HcclCommunicator accessors, CCL-buffer creation and CollComm init all mocked to succeed,
// InitCollCommInner is expected to return HCCL_SUCCESS.
TEST_F(HcclCommHostTest, Ut_InitCollCommInner_When_Success_Expect_Success)
{
    auto comm = std::make_shared<hccl::hcclComm>(1, 1, "test_comm");

    // Collaborators of InitCollCommInner, all stubbed to report success.
    MOCKER_CPP(&HcclCommunicator::GetConnectMode)
        .stubs()
        .will(returnValue(1));
    MOCKER_CPP(&HcclCommunicator::GetRankGraphV1)
        .stubs()
        .will(returnValue(reinterpret_cast<void*>(0x1000)));
    MOCKER_CPP(&HcclCommunicator::GetInCCLbuffer)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&hcclComm::CreateCommCCLbuffer)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::Init)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&CollComm::GetHDCommunicate)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    u32 targetUserRank = 0;
    const HcclResult result = comm->InitCollCommInner(targetUserRank);
    EXPECT_EQ(result, HCCL_SUCCESS);
}
// With a non-null binHandle_ and aclrtBinaryUnLoad mocked to fail (returns 1), BinaryUnLoad
// is expected to attempt the unload exactly once and still clear binHandle_.
// NOTE(review): the warning log named in the test is not asserted here.
TEST_F(HcclCommHostTest, Ut_BinaryUnLoad_When_BinHandleNotNullAndUnloadFailed_Expect_WarningLog)
{
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, "test_comm");
    hcclCommPtr->binHandle_ = reinterpret_cast<aclrtBinHandle>(0x1234);
    // expects(once()) rather than stubs(): the fixture's GlobalMockObject::verify() then fails
    // the test if the unload is never attempted.
    MOCKER(aclrtBinaryUnLoad)
        .expects(once())
        .will(returnValue(1));
    hcclCommPtr->BinaryUnLoad();
    EXPECT_EQ(hcclCommPtr->binHandle_, nullptr);
}
// With binHandle_ already null, BinaryUnLoad is expected to be a no-op. It must not call
// aclrtBinaryUnLoad, and binHandle_ must remain null.
TEST_F(HcclCommHostTest, Ut_BinaryUnLoad_When_BinHandleNull_Expect_Noop)
{
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, "test_comm");
    hcclCommPtr->binHandle_ = nullptr;
    // Without this expectation the test would also pass if the API were (wrongly) called.
    MOCKER(aclrtBinaryUnLoad)
        .expects(never());
    hcclCommPtr->BinaryUnLoad();
    EXPECT_EQ(hcclCommPtr->binHandle_, nullptr);
}
// With a non-null binHandle_ and aclrtBinaryUnLoad mocked to succeed (returns 0), BinaryUnLoad
// is expected to unload exactly once and clear binHandle_.
TEST_F(HcclCommHostTest, Ut_BinaryUnLoad_When_BinHandleNotNullAndUnloadSuccess_Expect_BinHandleSetNull)
{
    std::shared_ptr<hccl::hcclComm> hcclCommPtr = std::make_shared<hccl::hcclComm>(1, 1, "test_comm");
    hcclCommPtr->binHandle_ = reinterpret_cast<aclrtBinHandle>(0x1234);
    // expects(once()) rather than stubs(): clearing the handle without unloading would leak it.
    MOCKER(aclrtBinaryUnLoad)
        .expects(once())
        .will(returnValue(0));
    hcclCommPtr->BinaryUnLoad();
    EXPECT_EQ(hcclCommPtr->binHandle_, nullptr);
}