* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <string>
#include <emock/emock.hpp>
#include <gtest/gtest.h>
#include <mpi.h>
#include <runtime/base.h>
#include <runtime/dev.h>
#include <runtime/mem.h>
#include <hccl.h>
#include "lccl/include/lcal_comm.h"
#include "lccl/include/comm_args.h"
#include "lccl/include/lcal_api.h"
#include "lccl/src/tools/socket/lcal_sock_exchange.h"
#include "lccl/include/lcal_types.h"
namespace Lcal {
using std::string;
using std::vector;
class LcalCommTest : public ::testing::Test {
protected:
void SetUp() override
{
emock::GlobalMockObject::reset();
}
};
/**
 * Checks both constructor overloads:
 *  - with an explicit rank list, GetRank()/GetRankSize() echo the arguments;
 *  - with an empty LcalUniqueId, the comm reports zero comm size, a null
 *    comm-args pointer and zero cores.
 */
TEST_F(LcalCommTest, ConstructAndGetOk)
{
    vector<int> rankList{0};
    LcalComm listComm(0, 1, rankList);
    EXPECT_EQ(listComm.GetRank(), 0);
    EXPECT_EQ(listComm.GetRankSize(), 1);

    LcalComm idComm(0, 1, LcalUniqueId{});
    EXPECT_EQ(idComm.GetCommSize(), 0);
    EXPECT_EQ(idComm.GetCommArgsPtr(), nullptr);
    EXPECT_EQ(idComm.GetPhysicalInfo().coreNum, 0);
}
/**
 * Init() succeeds when GatherDevId, InitCommMem and SyncCommArgs are all
 * replaced with stubs that report LCAL_SUCCESS.
 */
TEST_F(LcalCommTest, InitOk)
{
    class AllStepsSucceedComm : public LcalComm {
    public:
        AllStepsSucceedComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevId() override { return LCAL_SUCCESS; }
        int InitCommMem() override { return LCAL_SUCCESS; }
        int SyncCommArgs() override { return LCAL_SUCCESS; }
    };

    AllStepsSucceedComm comm(0, 1);
    EXPECT_EQ(comm.Init(), LCAL_SUCCESS);
}
/** Init() is expected to reject a negative rank with LCAL_ERROR_PARA_CHECK_FAIL. */
TEST_F(LcalCommTest, InitErr_InvalidRank)
{
    constexpr int negativeRank = -1;
    LcalComm comm(negativeRank, 1);
    EXPECT_EQ(comm.Init(), LCAL_ERROR_PARA_CHECK_FAIL);
}
/** A failing GatherDevId() step must surface as LCAL_ERROR_INTERNAL from Init(). */
TEST_F(LcalCommTest, InitErr_GatherDevIdFailure)
{
    class GatherFailsComm : public LcalComm {
    public:
        GatherFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevId() override { return LCAL_ERROR_INTERNAL; }
    };

    GatherFailsComm comm(0, 1);
    EXPECT_EQ(comm.Init(), LCAL_ERROR_INTERNAL);
}
/**
 * With GatherDevId() stubbed to succeed and InitCommon() stubbed to fail,
 * Init() must report LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, InitErr_InitCommFailure)
{
    class CommonFailsComm : public LcalComm {
    public:
        CommonFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevId() override { return LCAL_SUCCESS; }
        int InitCommon() override { return LCAL_ERROR_INTERNAL; }
    };

    CommonFailsComm comm(0, 1);
    EXPECT_EQ(comm.Init(), LCAL_ERROR_INTERNAL);
}
/**
 * With GatherDevId() stubbed to succeed and InitCommMem() stubbed to fail,
 * Init() must report LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, InitErr_InitCommMemFailure)
{
    class CommMemFailsComm : public LcalComm {
    public:
        CommMemFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevId() override { return LCAL_SUCCESS; }
        int InitCommMem() override { return LCAL_ERROR_INTERNAL; }
    };

    CommMemFailsComm comm(0, 1);
    EXPECT_EQ(comm.Init(), LCAL_ERROR_INTERNAL);
}
/** InitCommon() must propagate an EnablePeerAccess() failure as LCAL_ERROR_INTERNAL. */
TEST_F(LcalCommTest, InitCommErr)
{
    class PeerAccessFailsComm : public LcalComm {
    public:
        PeerAccessFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int EnablePeerAccess() override { return LCAL_ERROR_INTERNAL; }
    };

    PeerAccessFailsComm comm(0, 1);
    EXPECT_EQ(comm.InitCommon(), LCAL_ERROR_INTERNAL);
}
/**
 * Every aclrtMalloc call is stubbed to report an allocation failure.
 * InitThread() is expected to translate that into LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, InitThreadErr_aclrtMallocFailure)
{
    EMOCK(aclrtMalloc)
        .stubs()
        .will(returnValue(ACL_ERROR_RT_MEMORY_ALLOCATION));

    LcalComm comm(0, 1);
    EXPECT_EQ(comm.InitThread(), LCAL_ERROR_INTERNAL);
}
/** InitThread() is expected to reject a negative rank with LCAL_ERROR_PARA_CHECK_FAIL. */
TEST_F(LcalCommTest, InitThreadErr_InvalidRank)
{
    constexpr int negativeRank = -1;
    LcalComm comm(negativeRank, 1);
    EXPECT_EQ(comm.InitThread(), LCAL_ERROR_PARA_CHECK_FAIL);
}
/** A failing GatherDevIdThread() step must surface as LCAL_ERROR_INTERNAL from InitThread(). */
TEST_F(LcalCommTest, InitThreadErr_GatherDevIdFailure)
{
    class ThreadGatherFailsComm : public LcalComm {
    public:
        ThreadGatherFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevIdThread() override { return LCAL_ERROR_INTERNAL; }
    };

    ThreadGatherFailsComm comm(0, 1);
    EXPECT_EQ(comm.InitThread(), LCAL_ERROR_INTERNAL);
}
/**
 * With GatherDevIdThread() stubbed to succeed and InitCommon() stubbed to fail,
 * InitThread() must report LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, InitThreadErr_InitCommonFailure)
{
    class ThreadCommonFailsComm : public LcalComm {
    public:
        ThreadCommonFailsComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int GatherDevIdThread() override { return LCAL_SUCCESS; }
        int InitCommon() override { return LCAL_ERROR_INTERNAL; }
    };

    ThreadCommonFailsComm comm(0, 1);
    EXPECT_EQ(comm.InitThread(), LCAL_ERROR_INTERNAL);
}
/**
 * EnablePeerAccess() succeeds when aclrtDeviceEnablePeerAccess always succeeds.
 * The comm sits on device 1 and its device list contains only device 0.
 */
TEST_F(LcalCommTest, EnablePeerAccessOk)
{
    EMOCK(aclrtDeviceEnablePeerAccess)
        .stubs()
        .with(any())
        .will(returnValue(ACL_SUCCESS));

    LcalComm comm(0, 1);
    comm.devList_ = vector<int>{0};
    comm.devId_ = 1;

    EXPECT_EQ(comm.EnablePeerAccess(), LCAL_SUCCESS);
}
/**
 * Stand-in for rtGetPairDevicesInfo that reports an invalid topology type (-1)
 * for every device pair.
 *
 * It has internal linkage so that other test translation units can define a
 * helper with the same name without an ODR/link clash.
 *
 * @param value  out-parameter receiving the topology type; ignored if null.
 * @return ACL_SUCCESS always.
 */
static aclError mockRtGetPairDevicesInfo(uint32_t /* devId */, uint32_t /* peerDevId */, int32_t /* flags */,
                                         int64_t* value)
{
    constexpr int64_t kInvalidTopologyType = -1;
    if (value != nullptr) {
        *value = kInvalidTopologyType;
    }
    return ACL_SUCCESS;
}
/**
 * EnablePeerAccess() is expected to fail with LCAL_ERROR_INTERNAL under these conditions:
 *  - every device pair reports an invalid topology (via mockRtGetPairDevicesInfo);
 *  - the physical link is RESERVED;
 *  - the rank size is one past PING_PONG_SIZE, which the test name suggests
 *    exceeds the PCIe limit. NOTE(review): confirm against EnablePeerAccess().
 */
TEST_F(LcalCommTest, EnablePeerAccessErr_InvalidRankSizeForPCIE)
{
    EMOCK(SkipUnusedChannel910B2C)
        .stubs()
        .with(any())
        .will(returnValue(false));
    EMOCK(rtGetPairDevicesInfo)
        .stubs()
        .will(invoke(mockRtGetPairDevicesInfo));
    EMOCK(aclrtDeviceEnablePeerAccess)
        .stubs()
        .with(any())
        .will(returnValue(ACL_SUCCESS));

    LcalComm comm(0, 1);
    comm.devList_ = vector<int>{0};
    comm.devId_ = 1;
    comm.rankSize_ = PING_PONG_SIZE + 1;
    comm.physicalInfo_.physicalLink = PhysicalLink::RESERVED;

    EXPECT_EQ(comm.EnablePeerAccess(), LCAL_ERROR_INTERNAL);
}
/** SetBuff() with a one-entry table holding nullptr leaves peerMems[0] null. */
TEST_F(LcalCommTest, CommArgsSetBuffOk)
{
    auto commArgs = CommArgs();
    // Stack-allocated one-entry buffer table: nothing to free, nothing leaks.
    int8_t* buff[1] = {nullptr};
    commArgs.SetBuff(buff);
    EXPECT_EQ(commArgs.peerMems[0], nullptr);
}
/** SetBuff() is expected to throw std::invalid_argument when rankSize exceeds LCAL_MAX_RANK_SIZE. */
TEST_F(LcalCommTest, CommArgsSetBuffErr)
{
    auto commArgs = CommArgs();
    // Stack-allocated one-entry buffer table: nothing to free, nothing leaks.
    int8_t* buff[1] = {nullptr};
    commArgs.rankSize = LCAL_MAX_RANK_SIZE + 1;
    EXPECT_THROW(commArgs.SetBuff(buff), std::invalid_argument);
}
/**
 * GatherDevId() succeeds under these stubs:
 *  - the socket exchange reports a single node;
 *  - AllGather succeeds without writing to recvBuf;
 *  - aclrtGetDevice returns ACL_SUCCESS.
 */
TEST_F(LcalCommTest, GatherDevId)
{
    class MockLcalSockExchange : public LcalSockExchange {
    public:
        MockLcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId)
            : LcalSockExchange(rank, rankSize, lcalCommId) {}
        int GetNodeNum() override
        {
            return 1;
        }
        int AllGather(const void* /* sendBuf */, size_t /* sendSize */, void* /* recvBuf */) override
        {
            return LCAL_SUCCESS;
        }
    };
    EMOCK(aclrtGetDevice).stubs().with(any()).will(returnValue(ACL_SUCCESS));
    auto comm = LcalComm(0, 1);
    // NOTE(review): assumes LcalComm takes ownership of socketExchange_ and deletes it — confirm.
    comm.socketExchange_ = new MockLcalSockExchange(0, 1, LcalUniqueId{});
    // Compare against the named status code, as every other test in this file does.
    EXPECT_EQ(comm.GatherDevId(), LCAL_SUCCESS);
}
/** InitMem() succeeds when aclrtMalloc and aclrtMemset are both stubbed to succeed. */
TEST_F(LcalCommTest, InitMemOk)
{
    EMOCK(aclrtMalloc)
        .stubs()
        .will(returnValue(ACL_SUCCESS));
    EMOCK(aclrtMemset)
        .stubs()
        .will(returnValue(ACL_SUCCESS));

    LcalComm comm(0, 1);
    EXPECT_EQ(comm.InitMem(), LCAL_SUCCESS);
}
/**
 * GetSidId() succeeds for a CHIP_910_9391 comm under these stubs:
 *  - rtGetDeviceInfo returns ACL_SUCCESS;
 *  - the socket exchange reports a single node;
 *  - AllGather succeeds without writing to recvBuf.
 */
TEST_F(LcalCommTest, GetSidIdOk)
{
    class MockLcalSockExchange : public LcalSockExchange {
    public:
        MockLcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId)
            : LcalSockExchange(rank, rankSize, lcalCommId) {}
        int GetNodeNum() override
        {
            return 1;
        }
        int AllGather(const void* /* sendBuf */, size_t /* sendSize */, void* /* recvBuf */) override
        {
            return LCAL_SUCCESS;
        }
    };
    EMOCK(rtGetDeviceInfo).stubs().with(any()).will(returnValue(ACL_SUCCESS));
    auto comm = LcalComm(0, 1);
    // NOTE(review): assumes LcalComm takes ownership of socketExchange_ and deletes it — confirm.
    comm.socketExchange_ = new MockLcalSockExchange(0, 1, LcalUniqueId{});
    comm.devList_ = vector<int>{0};
    comm.physicalInfo_.chipName = ChipName::CHIP_910_9391;
    int64_t sdids[LCAL_MAX_RANK_SIZE] = {0};
    EXPECT_EQ(comm.GetSidId(sdids), LCAL_SUCCESS);
}
/**
 * GetPid() succeeds when rtDeviceGetBareTgid is stubbed to succeed and
 * AllGather succeeds without writing to recvBuf.
 */
TEST_F(LcalCommTest, GetPidOk)
{
    class MockLcalSockExchange : public LcalSockExchange {
    public:
        MockLcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId)
            : LcalSockExchange(rank, rankSize, lcalCommId) {}
        int AllGather(const void* /* sendBuf */, size_t /* sendSize */, void* /* recvBuf */) override
        {
            return LCAL_SUCCESS;
        }
    };
    EMOCK(rtDeviceGetBareTgid).stubs().with(any()).will(returnValue(ACL_SUCCESS));
    auto comm = LcalComm(0, 1);
    // NOTE(review): assumes LcalComm takes ownership of socketExchange_ and deletes it — confirm.
    comm.socketExchange_ = new MockLcalSockExchange(0, 1, LcalUniqueId{});
    // Zero-initialised stack buffer (same pattern as sdids in GetSidIdOk); no heap allocation to leak.
    uint32_t pids[LCAL_MAX_RANK_SIZE] = {0};
    EXPECT_EQ(comm.GetPid(pids), LCAL_SUCCESS);
}
/** GetName() succeeds when AllGather succeeds without writing to recvBuf. */
TEST_F(LcalCommTest, GetNameOk)
{
    class MockLcalSockExchange : public LcalSockExchange {
    public:
        MockLcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId)
            : LcalSockExchange(rank, rankSize, lcalCommId) {}
        int AllGather(const void* /* sendBuf */, size_t /* sendSize */, void* /* recvBuf */) override
        {
            return LCAL_SUCCESS;
        }
    };
    auto comm = LcalComm(0, 1);
    // NOTE(review): assumes LcalComm takes ownership of socketExchange_ and deletes it — confirm.
    comm.socketExchange_ = new MockLcalSockExchange(0, 1, LcalUniqueId{});
    auto name = string("test");
    // Zero-filled stack table; it decays to the same char (*)[IPC_NAME_SIZE] the heap version had, without the leak.
    char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE] = {};
    EXPECT_EQ(comm.GetName(name, names), LCAL_SUCCESS);
}
/**
 * InitCommMem() returns LCAL_SUCCESS when every step it is overridden to use
 * succeeds: InitMem, GetPid, GetSidId, SetMemoryName, SetIpcPidSdid, GetName
 * and OpenIpcMem.
 */
TEST_F(LcalCommTest, InitCommMemOk)
{
    class MockLcalComm : public LcalComm {
    public:
        MockLcalComm(int rank, int rankSize) : LcalComm(rank, rankSize) {}
        int InitMem() override
        {
            return LCAL_SUCCESS;
        }
        int GetPid(uint32_t /* pids */[LCAL_MAX_RANK_SIZE]) override
        {
            return LCAL_SUCCESS;
        }
        int GetSidId(int64_t /* sdids */[LCAL_MAX_RANK_SIZE]) override
        {
            return LCAL_SUCCESS;
        }
        int SetMemoryName(std::string& /* name */) override
        {
            return LCAL_SUCCESS;
        }
        int SetIpcPidSdid(std::string& /* name */, const uint32_t* /* pids */,
                          const int64_t* /* sdids */) const override
        {
            return LCAL_SUCCESS;
        }
        int GetName(std::string& /* name */, char /* names */[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]) override
        {
            return LCAL_SUCCESS;
        }
        int OpenIpcMem(const char /* names */[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]) override
        {
            return LCAL_SUCCESS;
        }
    };
    auto comm = MockLcalComm(0, 1);
    // InitCommMem() returns an LCAL status code, so compare against LCAL_SUCCESS, not ACL_SUCCESS.
    EXPECT_EQ(comm.InitCommMem(), LCAL_SUCCESS);
}
/**
 * CloseIpcMem() on a two-rank comm, with a fake mapping for rank 1's peer
 * memory, is expected to call rtIpcCloseMemory exactly once.
 */
TEST_F(LcalCommTest, CloseIpcMemOk)
{
    EMOCK(rtIpcCloseMemory).expects(once()).will(returnValue(ACL_SUCCESS));
    // Non-null placeholder for rank 1's peer memory; rtIpcCloseMemory is mocked, so any address works.
    // Static storage keeps the address valid for the whole run and needs no delete[].
    static int8_t fakePeerMem = 0;
    auto comm = LcalComm(0, 2);
    comm.peerMem_[1] = &fakePeerMem;
    comm.CloseIpcMem();
}
/**
 * Calls OpenIpcMem() twice on a two-rank comm and expects both calls to succeed.
 * GetChipName is scripted to answer exactly three calls, in order:
 * CHIP_910B2C, then CHIP_910_9361 twice.
 */
TEST_F(LcalCommTest, OpenIpcMemOk)
{
    EMOCK(GetChipName)
        .expects(exactly(3))
        .will(returnObjectList(ChipName::CHIP_910B2C, ChipName::CHIP_910_9361, ChipName::CHIP_910_9361));
    EMOCK(SkipUnusedChannel910B2C).stubs().will(returnValue(false));
    EMOCK(rtIpcOpenMemory).stubs().with(any()).will(returnValue(ACL_SUCCESS));
    auto comm = LcalComm(0, 2);
    // Zero-filled stack table of IPC names; no heap allocation to leak.
    char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE] = {};
    EXPECT_EQ(comm.OpenIpcMem(names), LCAL_SUCCESS);
    EXPECT_EQ(comm.OpenIpcMem(names), LCAL_SUCCESS);
}
/**
 * SetMemoryName() succeeds for a CHIP_910_9382 chip. Both GetChipName and
 * rtIpcSetMemoryName are expected to be called once.
 */
TEST_F(LcalCommTest, SetMemoryNameOk)
{
    EMOCK(GetChipName)
        .expects(once())
        .will(returnValue(ChipName::CHIP_910_9382));
    EMOCK(rtIpcSetMemoryName)
        .expects(once())
        .will(returnValue(ACL_SUCCESS));

    LcalComm comm(0, 1);
    string ipcName("test");
    EXPECT_EQ(comm.SetMemoryName(ipcName), LCAL_SUCCESS);
}
/**
 * SetIpcPidSdid() on rank 1 of a two-rank comm is called once with chip
 * CHIP_910B2C and once with CHIP_910_9361; both calls must succeed.
 * rtSetIpcMemPid and rtSetIpcMemorySuperPodPid are each expected to be called
 * exactly twice in total.
 */
TEST_F(LcalCommTest, SetIpcPidSdidOk)
{
    EMOCK(rtSetIpcMemPid).expects(exactly(2)).will(returnValue(ACL_SUCCESS));
    EMOCK(rtSetIpcMemorySuperPodPid).expects(exactly(2)).will(returnValue(ACL_SUCCESS));
    auto name = string("test");
    // Zero-initialised stack buffers; no heap allocations to leak.
    uint32_t pids[LCAL_MAX_RANK_SIZE] = {0};
    int64_t sdids[LCAL_MAX_RANK_SIZE] = {0};
    auto comm = LcalComm(1, 2);
    comm.physicalInfo_.chipName = ChipName::CHIP_910B2C;
    auto res = comm.SetIpcPidSdid(name, pids, sdids);
    EXPECT_EQ(res, LCAL_SUCCESS);
    comm.physicalInfo_.chipName = ChipName::CHIP_910_9361;
    res = comm.SetIpcPidSdid(name, pids, sdids);
    EXPECT_EQ(res, LCAL_SUCCESS);
}
/**
 * aclrtMalloc is expected once and is stubbed to fail. SyncCommArgs() must
 * then report LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, SyncCommArgsErr_aclrtMallocFailure)
{
    EMOCK(aclrtMalloc)
        .expects(once())
        .will(returnValue(ACL_ERROR_RT_MEMORY_ALLOCATION));

    LcalComm comm(0, 1);
    EXPECT_EQ(comm.SyncCommArgs(), LCAL_ERROR_INTERNAL);
}
/**
 * aclrtMalloc succeeds, but the single expected aclrtMemcpy call is stubbed
 * to fail. SyncCommArgs() must then report LCAL_ERROR_INTERNAL.
 */
TEST_F(LcalCommTest, SyncCommArgsErr_aclrtMemcpyFailure)
{
    EMOCK(aclrtMalloc)
        .expects(once())
        .will(returnValue(ACL_SUCCESS));
    EMOCK(aclrtMemcpy)
        .expects(once())
        .will(returnValue(ACL_ERROR_RT_PARAM_INVALID));

    LcalComm comm(0, 1);
    EXPECT_EQ(comm.SyncCommArgs(), LCAL_ERROR_INTERNAL);
}
/** Table-driven check of GetCoreNum() for three chip models. */
TEST_F(LcalCommTest, GetCoreNumOk)
{
    constexpr int AI_CORE_NUM_24 = 24;
    constexpr int AI_CORE_NUM_20 = 20;
    constexpr int AI_CORE_NUM_2 = 2;

    struct CoreNumCase {
        ChipName chip;
        int expectedCores;
    };
    const CoreNumCase cases[] = {
        {ChipName::CHIP_910B2C, AI_CORE_NUM_24},
        {ChipName::CHIP_910B3, AI_CORE_NUM_20},
        {ChipName::CHIP_310P3, AI_CORE_NUM_2},
    };

    for (const auto& testCase : cases) {
        EXPECT_EQ(GetCoreNum(testCase.chip), testCase.expectedCores);
    }
}
/**
 * SkipUnusedChannel910B2C(1, 8, CHIP_910B2C) must return true.
 * The function is free-standing (called without any object), so no
 * LcalComm instance is needed.
 */
TEST_F(LcalCommTest, SkipUnusedChannel910B2COk)
{
    EXPECT_TRUE(SkipUnusedChannel910B2C(1, 8, ChipName::CHIP_910B2C));
}
}