* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <securec.h>
#include <ifaddrs.h>
#include <sys/socket.h>
#include <netdb.h>
#include <sys/types.h>
#include <stddef.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <memory>
#include <iostream>
#include <fstream>
#include <hccl/hccl_comm.h>
#include <hccl/hccl_inner.h>
#include "llt_hccl_stub_pub.h"
#include "v80_rank_table.h"
#include "hccl/base.h"
#include <hccl/hccl_types.h>
#include "sal.h"
#include "llt_hccl_stub_gdr.h"
#include "network_manager_pub.h"
#include <externalinput_pub.h>
#include "tsd/tsd_client.h"
#include "dltdt_function.h"
#include "dlra_function.h"
#include "externalinput.h"
#include "adapter_rts.h"
#define private public
#define protected public
#include "hccl_socket.h"
#include "hccl_socket_manager.h"
#include "hccl_communicator.h"
#include "zero_copy/zero_copy_address_mgr.h"
#undef private
#undef protected
#include "socket/hccl_network.h"
#include <deque>
#include <queue>
#include <mutex>
#include <utility>
using namespace std;
using namespace hccl;
// Stub for hrtRaSocketNonBlockSend: claims the whole buffer was sent in one call and reports success.
// fdHandle and data are ignored.
s32 stub_SocketManagerTest_hrtRaSocketNonBlockSendHB(
    const FdHandle fdHandle, const void *data, u64 size, u64 *sent_size)
{
    const u64 bytesAccepted = size;
    *sent_size = bytesAccepted;
    return 0;
}
/**
 * Appends the raw bytes of `value` at the write cursor `exchangeDataPtr`.
 *
 * On success the cursor advances by sizeof(T) and `exchangeDataBlankSize` (remaining capacity)
 * shrinks by the same amount. The memcpy_s result goes through CHK_SAFETY_FUNC_RET, which presumably
 * returns an error from this function when the copy is rejected (e.g. not enough room left).
 */
template <typename T>
HcclResult ConstructData(u8 *&exchangeDataPtr, u32 &exchangeDataBlankSize, T &value)
{
    const size_t fieldBytes = sizeof(value);
    CHK_SAFETY_FUNC_RET(memcpy_s(exchangeDataPtr, exchangeDataBlankSize, &value, fieldBytes));
    exchangeDataBlankSize -= fieldBytes;
    exchangeDataPtr += fieldBytes;
    return HCCL_SUCCESS;
}
// Messages captured by the send stubs; the receive stubs pop them to play the remote peer.
static std::queue<std::vector<u8>> exchangeDataForAck_;
// Fake 2 MiB backing buffers (262144 x uint64_t), keyed by device id; their addresses stand in for
// virtual addresses in requests and replies.
static std::unordered_map<u32, std::array<uint64_t, 2 * 1024 * 1024 / sizeof(uint64_t)>> vir_ptr_map;
// Device id the fake peer writes into its replies.
u32 devicePhyId_{1};
// NOTE(review): not referenced in the visible code - every function declares its own local `addr`.
u64 addr{};
size_t size{2 * 1024 * 1024};
// Range length / alignment / flags used by the tests and echoed in fake SET_MEMORY_RANGE replies.
size_t lenth{1};
size_t alignment{2 * 1024 * 1024};
uint64_t flags{1};
// Guards exchangeDataForAck_ across the send/receive stubs.
std::mutex stub_ZeroCopyMemoryAgentUt_mutex;
// Stub for HcclSocket::Send: under the shared stub mutex, enqueues a copy of the outgoing `size`
// bytes so a receive stub can later answer it. Always reports success.
HcclResult stub_ZeroCopyMemoryAgentSt_Send(hccl::HcclSocket * socket, const void *data, u64 size)
{
    std::lock_guard<std::mutex> guard(stub_ZeroCopyMemoryAgentUt_mutex);
    exchangeDataForAck_.emplace(size);  // zero-filled buffer of `size` bytes
    std::vector<u8> &packet = exchangeDataForAck_.back();
    memcpy_s(packet.data(), size, data, size);
    return HCCL_SUCCESS;
}
// Serializes the canned reply the fake peer returns for `type` into `buf` (capacity `bufLen`).
//
// Layout, in order:
//   - every reply: request type, devicePhyId_
//   - every type except SET_REMOTE_BARE_TGID: the address of vir_ptr_map[devicePhyId_]
//   - SET_MEMORY_RANGE: + lenth, alignment, flags
//   - ACTIVATE_COMM_MEMORY: + lenth, int offset (0), uint64 shareable handle (0x01), flags
//
// NOTE(review): presumably this mirrors the layout the agent decodes; keep field order and widths
// in sync with it.
static HcclResult BuildZeroCopyStubReply(RequestType type, u8 *buf, u32 bufLen)
{
    u8 *cursor = buf;
    u32 blank = bufLen;
    CHK_RET(ConstructData(cursor, blank, type));
    CHK_RET(ConstructData(cursor, blank, devicePhyId_));
    if (type == RequestType::SET_REMOTE_BARE_TGID) {
        // This reply carries no address, so vir_ptr_map is deliberately left untouched.
        return HCCL_SUCCESS;
    }
    // operator[] creates the device's backing buffer on first use; node addresses are stable.
    u64 virAddr = reinterpret_cast<u64>(vir_ptr_map[devicePhyId_].data());
    CHK_RET(ConstructData(cursor, blank, virAddr));
    if (type == RequestType::SET_MEMORY_RANGE) {
        CHK_RET(ConstructData(cursor, blank, lenth));
        CHK_RET(ConstructData(cursor, blank, alignment));
        CHK_RET(ConstructData(cursor, blank, flags));
    } else if (type == RequestType::ACTIVATE_COMM_MEMORY) {
        int offset = 0;
        uint64_t shareableHandle = 0x01;
        CHK_RET(ConstructData(cursor, blank, lenth));
        CHK_RET(ConstructData(cursor, blank, offset));
        CHK_RET(ConstructData(cursor, blank, shareableHandle));
        CHK_RET(ConstructData(cursor, blank, flags));
    }
    return HCCL_SUCCESS;
}

/**
 * Fake remote peer used by the receive stubs. Pops the oldest message queued by the send stubs and
 * writes a reply into recvBuf.
 *
 * For the request types handled below, the reply is synthesized by BuildZeroCopyStubReply and fills
 * all recvBufLen bytes (the unused tail is zero). Any other type is echoed back verbatim.
 *
 * Both visible callers hold stub_ZeroCopyMemoryAgentUt_mutex while calling this.
 *
 * @param socket     unused.
 * @param recvBuf    destination buffer, recvBufLen bytes.
 * @param recvBufLen capacity of recvBuf.
 * @param compSize   out: number of bytes reported as received; 0 when nothing is queued.
 * @return HCCL_SUCCESS, or the error from serializing a reply that does not fit in recvBufLen.
 */
HcclResult ZeroCopyMemoryAgentRecv(hccl::HcclSocket *socket, void *recvBuf, u32 recvBufLen, u64 &compSize)
{
    // front() on an empty std::queue is undefined behaviour; report "nothing received" instead,
    // matching what the calling stubs do on an empty queue.
    if (exchangeDataForAck_.empty()) {
        compSize = 0;
        return HCCL_SUCCESS;
    }
    std::vector<u8> temp = std::move(exchangeDataForAck_.front());
    exchangeDataForAck_.pop();

    // Decode the type only when the message is long enough to hold one. Shorter messages keep
    // RESERVED and fall through to the verbatim echo below.
    RequestType requestType = RequestType::RESERVED;
    if (temp.size() >= sizeof(RequestType)) {
        memcpy_s(&requestType, sizeof(RequestType), temp.data(), sizeof(RequestType));
    }

    switch (requestType) {
        case RequestType::SET_MEMORY_RANGE:
        case RequestType::UNSET_MEMORY_RANGE:
        case RequestType::SET_MEMORY_RANGE_ACK:
        case RequestType::SET_REMOTE_BARE_TGID:
        case RequestType::SET_REMOTE_BARE_TGID_ACK:
        case RequestType::ACTIVATE_COMM_MEMORY:
        case RequestType::DEACTIVATE_COMM_MEMORY:
        case RequestType::BARRIER_CLOSE: {
            std::vector<u8> reply(recvBufLen);  // value-initialized: unused tail bytes are zero
            CHK_RET(BuildZeroCopyStubReply(requestType, reply.data(), recvBufLen));
            memcpy_s(recvBuf, recvBufLen, reply.data(), reply.size());
            compSize = recvBufLen;
            break;
        }
        default:
            memcpy_s(recvBuf, recvBufLen, temp.data(), temp.size());
            compSize = temp.size();
            break;
    }
    return HCCL_SUCCESS;
}
// Stub for HcclSocket::IRecv. Under the shared stub mutex it either answers the oldest queued
// message through the fake peer, or reports a zero-byte receive when nothing is queued yet.
HcclResult stub_ZeroCopyMemoryAgent_IRecv(hccl::HcclSocket *socket, void *recvBuf, u32 recvBufLen, u64 &compSize)
{
    std::lock_guard<std::mutex> guard(stub_ZeroCopyMemoryAgentUt_mutex);
    if (!exchangeDataForAck_.empty()) {
        return ZeroCopyMemoryAgentRecv(socket, recvBuf, recvBufLen, compSize);
    }
    compSize = 0;
    return HCCL_SUCCESS;
}
// Stub for hrtRaGetSockets: marks every requested connection as CONNECT_OK with a null fd handle
// and reports all `num` of them as connected. `role` is ignored.
// (The original also filled a function-static vector that was never read and grew on every call;
// it has been dropped.)
s32 stub_SocketManagerTest_hrtRaGetSockets(u32 role, struct SocketInfoT conn[], u32 num, u32 *connectedNum)
{
    for (u32 i = 0; i < num; ++i) {
        conn[i].fdHandle = nullptr;
        conn[i].status = CONNECT_OK;
    }
    *connectedNum = num;
    return 0;
}
// Stub for GetIsSupSockBatchCloseImmed: reports support for immediate batch socket close on every
// device (phyId is ignored).
HcclResult stub_SocketManagerTest_GetIsSupSockBatchCloseImmed(u32 phyId, bool &isSupportBatchClose)
{
    const bool supported = true;
    isSupportBatchClose = supported;
    return HCCL_SUCCESS;
}
// Stub for hrtRaBlockGetSockets: gives each requested connection its own fake fd slot (an int
// holding 0) and marks it CONNECT_OK. `role` is ignored.
//
// The slots live in a function-static std::deque. deque::push_back never relocates existing
// elements, so fd pointers handed out earlier (in this call or previous calls) stay valid. The
// original used std::vector, whose reallocation on push_back left those pointers dangling.
HcclResult stub_exchangerSocketTest_hrtRaBlockGetSockets(u32 role, struct SocketInfoT conn[], u32 num)
{
    static std::deque<int> fdHandle;
    for (u32 i = 0; i < num; ++i) {
        fdHandle.push_back(0);
        conn[i].fdHandle = &fdHandle.back();
        conn[i].status = CONNECT_OK;
    }
    return HCCL_SUCCESS;
}
// Stub for NetworkManager::GetRaResourceInfo: returns a copy of a fixed fake resource table with
// one VNIC and one NIC socket entry for a single fake IP, both pointing at the same dummy handle.
//
// The table is built once. The original never set `initialized`, so it rebuilt the table on every
// call, and it wrote the same map key eight times in a loop.
HcclResult stub_GetRaResourceInfo_exchangerSocketTest(NetworkManager *that, RaResourceInfo &raResourceInfo)
{
    static bool initialized = false;
    static RaResourceInfo fake_raResourceInfo;
    static int fake_handle = 1;
    if (!initialized) {
        const HcclIpAddress ipAddr(1684515008);
        IpSocket tmpIpSocket;
        tmpIpSocket.nicSocketHandle = &fake_handle;
        fake_raResourceInfo.vnicSocketMap[ipAddr] = tmpIpSocket;
        fake_raResourceInfo.nicSocketMap[ipAddr] = tmpIpSocket;
        initialized = true;
    }
    raResourceInfo = fake_raResourceInfo;
    return HCCL_SUCCESS;
}
// Stub for hrtRaSocketNonBlockRecv. Calls alternate between "nothing arrived" (*recvSize is left
// untouched) and "the whole buffer arrived" (*recvSize = size). The first call delivers nothing.
// NOTE(review): the counter is reset after every delivery, so `% 5` never gets past 1 and the
// pattern is strictly alternating. Confirm that is what was intended.
s32 stub_SocketManagerTest_hrtRaSocketNonBlockRecvHB(const FdHandle fdHandle, void *data, u64 size, u64 *recvSize)
{
    static u32 count = 0;
    const bool deliver = (count % 5U) != 0U;
    if (!deliver) {
        ++count;
        return 0;
    }
    *recvSize = size;
    count = 0;
    return 0;
}
class ZeroCopyMemoryAgentUt : public testing::Test
{
protected:
static void SetUpTestCase()
{
DlTdtFunction::GetInstance().DlTdtFunctionInit();
DlRaFunction::GetInstance().DlRaFunctionInit();
TsdOpen(0,2);
std::cout << "\033[36m--OneSidedUt SetUP--\033[0m" << std::endl;
}
static void TearDownTestCase()
{
TsdClose(0);
std::cout << "\033[36m--OneSidedUt TearDown--\033[0m" << std::endl;
}
virtual void SetUp()
{
MOCKER(hrtRaSocketNonBlockRecv).stubs().will(invoke(stub_SocketManagerTest_hrtRaSocketNonBlockRecvHB));
MOCKER(hrtRaSocketWhiteListAdd).stubs().will(returnValue(HCCL_SUCCESS));
MOCKER(hrtRaSocketWhiteListDel).stubs().will(returnValue(HCCL_SUCCESS));
MOCKER(hrtRaSocketBatchConnect).stubs().will(returnValue(HCCL_SUCCESS));
MOCKER(hrtRaGetSockets).stubs().will(invoke(stub_SocketManagerTest_hrtRaGetSockets));
MOCKER(hrtRaSocketBatchClose).stubs().will(returnValue(HCCL_SUCCESS));
MOCKER(hrtRaSocketNonBlockSend).stubs().will(invoke(stub_SocketManagerTest_hrtRaSocketNonBlockSendHB));
MOCKER(hrtRaBlockGetSockets).stubs().will(invoke(stub_exchangerSocketTest_hrtRaBlockGetSockets));
MOCKER_CPP(&NetworkManager::GetRaResourceInfo).stubs().will(invoke(stub_GetRaResourceInfo_exchangerSocketTest));
hrtSetDevice(0);
ResetInitState();
DlRaFunction::GetInstance().DlRaFunctionInit();
ClearHalEvent();
struct RaInitConfig config;
std::cout << "A Test SetUP" << std::endl;
}
virtual void TearDown()
{
GlobalMockObject::verify();
std::cout << "A Test TearDown" << std::endl;
}
};
void get_ranks_1server_2dev(std::vector<RankInfo>& rank_vector)
{
RankInfo tmp_para_0;
tmp_para_0.userRank = 0;
tmp_para_0.devicePhyId = 0;
tmp_para_0.deviceType = DevType::DEV_TYPE_910;
tmp_para_0.serverIdx = 0;
tmp_para_0.serverId = "10.0.0.10";
tmp_para_0.nicIp.push_back(HcclIpAddress("192.168.0.11"));
tmp_para_0.nicDeploy = NICDeployment::NIC_DEPLOYMENT_DEVICE;
RankInfo tmp_para_1;
tmp_para_1.userRank = 1;
tmp_para_1.devicePhyId = 1;
tmp_para_1.deviceType = DevType::DEV_TYPE_910;
tmp_para_1.serverIdx = 0;
tmp_para_1.serverId = "10.0.0.10";
tmp_para_1.nicIp.push_back(HcclIpAddress("192.168.0.12"));
tmp_para_1.nicDeploy = NICDeployment::NIC_DEPLOYMENT_DEVICE;
rank_vector.push_back(tmp_para_0);
rank_vector.push_back(tmp_para_1);
return;
}
// Stub for aclrtReserveMemAddress: always hands out device 1's 2 MiB fake buffer from vir_ptr_map
// (creating it on first use). size, alignment, expectPtr and flags are ignored.
// The original also assigned to the by-value `expectPtr`, which had no effect on the caller; that
// dead store has been removed.
aclError aclrtReserveMemAddress_stub(void **virPtr, size_t size, size_t alignment, void *expectPtr, uint64_t flags)
{
    CHK_PTR_NULL(virPtr);
    *virPtr = vir_ptr_map[1].data();
    return ACL_SUCCESS;
}
// Stub for aclrtMemImportFromShareableHandle: returns the address of deviceId's fake buffer in
// vir_ptr_map (creating it on first use) as the imported handle. shareableHandle is ignored.
// The output pointer is null-checked, as in aclrtReserveMemAddress_stub, instead of being
// dereferenced blindly.
rtError_t aclrtMemImportFromShareableHandle_stub(uint64_t shareableHandle, int32_t deviceId, aclrtDrvMemHandle *handle)
{
    CHK_PTR_NULL(handle);
    *handle = vir_ptr_map[deviceId].data();
    return ACL_SUCCESS;
}
// Drives the agent through Init -> Set/Activate/Deactivate/Unset memory range -> BarrierClose ->
// DeInit over the synchronous socket API. Send queues each request and IRecv answers it through the
// in-process fake peer.
TEST_F(ZeroCopyMemoryAgentUt, ut_agent_test)
{
    MOCKER(GetIsSupSockBatchCloseImmed).stubs().will(invoke(stub_SocketManagerTest_GetIsSupSockBatchCloseImmed));
    u32 interfaceVersion = 1;
    MOCKER(hrtRaGetInterfaceVersion)
        .expects(atMost(2))
        .with(mockcpp::any(), mockcpp::any(), outBoundP(&interfaceVersion))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&HcclSocket::Send, HcclResult(HcclSocket::*)(const void *, u64))
        .stubs()
        .with(mockcpp::any())
        .will(invoke(stub_ZeroCopyMemoryAgentSt_Send));
    MOCKER_CPP(&HcclSocket::IRecv).stubs().will(invoke(stub_ZeroCopyMemoryAgent_IRecv));
    MOCKER(aclrtMapMem).stubs().will(returnValue(ACL_SUCCESS));
    MOCKER(aclrtUnmapMem).stubs().will(returnValue(ACL_SUCCESS));
    std::unique_ptr<HcclSocketManager> socketManager;
    socketManager.reset(new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIPs(0x01);
    HcclResult ret = HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    std::vector<RankInfo> rank_vector;
    get_ranks_1server_2dev(rank_vector);
    // Named `agent` so the variable no longer shadows the ZeroCopyMemoryAgent class name.
    ZeroCopyMemoryAgent agent(socketManager, 0, 0, localIPs, rank_vector, 0, true, "ZeroCopyMemoryAgentTest");
    EXPECT_EQ(agent.Init(), HCCL_SUCCESS);
    MOCKER(aclrtReserveMemAddress).stubs().will(invoke(aclrtReserveMemAddress_stub));
    MOCKER(aclrtMemImportFromShareableHandle).stubs().will(invoke(aclrtMemImportFromShareableHandle_stub));
    // The local range is device 0's fake buffer; device 3's buffer is passed as the activation handle.
    void *localBase = vir_ptr_map[0].data();
    EXPECT_EQ(agent.SetMemoryRange(localBase, lenth, alignment, flags), HCCL_SUCCESS);
    void *remoteBase = vir_ptr_map[3].data();
    EXPECT_EQ(agent.ActivateCommMemory(localBase, lenth, 0, remoteBase, flags), HCCL_SUCCESS);
    EXPECT_EQ(agent.DeactivateCommMemory(localBase), HCCL_SUCCESS);
    EXPECT_EQ(agent.UnsetMemoryRange(localBase), HCCL_SUCCESS);
    EXPECT_EQ(agent.BarrierClose(), HCCL_SUCCESS);
    EXPECT_EQ(agent.DeInit(), HCCL_SUCCESS);
    agent.mapDevPhyIdconnectedSockets_.clear();
    HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0);
}
// With the link timeout mocked to 0 and two placeholder peer sockets, WaitForAllRemoteComplete
// for SET_MEMORY_RANGE is expected to fail whether 100 or just 1 request is outstanding.
TEST_F(ZeroCopyMemoryAgentUt, ut_agent_wait_timeout)
{
    std::unique_ptr<HcclSocketManager> sockMgr(
        new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIp(0x01);
    std::vector<RankInfo> noRanks;
    ZeroCopyMemoryAgent memAgent(sockMgr, 0, 0, localIp, noRanks, 0, false, "wait_timeout");
    s32 zeroTimeout = 0;
    MOCKER(GetExternalInputHcclLinkTimeOut).stubs().will(returnValue(zeroTimeout));
    memAgent.mapDevPhyIdconnectedSockets_[0] = std::make_shared<HcclSocket>(nullptr, 0);
    memAgent.mapDevPhyIdconnectedSockets_[1] = std::make_shared<HcclSocket>(nullptr, 0);
    const int setRangeIdx = static_cast<int>(RequestType::SET_MEMORY_RANGE);
    memAgent.reqMsgCounter_[setRangeIdx] = 100;
    EXPECT_NE(memAgent.WaitForAllRemoteComplete(RequestType::SET_MEMORY_RANGE), HCCL_SUCCESS);
    memAgent.reqMsgCounter_[setRangeIdx] = 1;
    EXPECT_NE(memAgent.WaitForAllRemoteComplete(RequestType::SET_MEMORY_RANGE), HCCL_SUCCESS);
}
// Stub for HcclSocket::SendAsync. Under the shared stub mutex it enqueues a copy of the outgoing
// bytes for the fake peer, reports all `size` bytes as sent, and returns the dummy request handle 0x01.
HcclResult stub_ZeroCopyMemoryAgent_SendAsync(hccl::HcclSocket *socket, const void *data, u64 size,
    u64 *sentSize, void **reqHandle)
{
    std::lock_guard<std::mutex> guard(stub_ZeroCopyMemoryAgentUt_mutex);
    exchangeDataForAck_.emplace(size);  // zero-filled buffer of `size` bytes
    std::vector<u8> &packet = exchangeDataForAck_.back();
    memcpy_s(packet.data(), size, data, size);
    *sentSize = size;
    *reqHandle = reinterpret_cast<void *>(0x01);
    return HCCL_SUCCESS;
}
// Stub for HcclSocket::RecvAsync. It always returns the dummy request handle 0x02. Under the shared
// stub mutex it answers the oldest queued message through the fake peer, or reports 0 bytes when
// nothing is queued.
HcclResult stub_ZeroCopyMemoryAgent_RecvAsync(hccl::HcclSocket *socket, void *recvBuf, u64 recvBufLen,
    u64 *receivedSize, void **reqHandle)
{
    *reqHandle = reinterpret_cast<void *>(0x02);
    std::lock_guard<std::mutex> guard(stub_ZeroCopyMemoryAgentUt_mutex);
    if (exchangeDataForAck_.empty()) {
        *receivedSize = 0;
        return HCCL_SUCCESS;
    }
    u64 delivered = 0;
    const HcclResult result = ZeroCopyMemoryAgentRecv(socket, recvBuf, recvBufLen, delivered);
    *receivedSize = delivered;
    return result;
}
// Stub for HcclSocket::GetAsyncReqResult: every async request is reported as already completed
// successfully, regardless of reqHandle.
HcclResult stub_ZeroCopyMemoryAgent_GetAsyncReqResult(hccl::HcclSocket *socket, void *reqHandle, HcclResult &reqResult)
{
    const HcclResult completed = HCCL_SUCCESS;
    reqResult = completed;
    return completed;
}
// Same lifecycle as ut_agent_test, but with the async socket API forced on. SendAsync, RecvAsync and
// GetAsyncReqResult are routed through the in-process fake peer.
TEST_F(ZeroCopyMemoryAgentUt, Ut_AgentFunc_When_UseAsyncSocketApi_ExpectNorm)
{
    MOCKER(GetIsSupSockBatchCloseImmed).stubs().will(invoke(stub_SocketManagerTest_GetIsSupSockBatchCloseImmed));
    u32 interfaceVersion = 1;
    MOCKER(hrtRaGetInterfaceVersion).expects(atMost(2))
        .with(mockcpp::any(), mockcpp::any(), outBoundP(&interfaceVersion))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(HcclSocket::IsSupportAsync).stubs().will(returnValue(true));
    MOCKER_CPP(&HcclSocket::SendAsync).stubs().will(invoke(stub_ZeroCopyMemoryAgent_SendAsync));
    MOCKER_CPP(&HcclSocket::RecvAsync).stubs().will(invoke(stub_ZeroCopyMemoryAgent_RecvAsync));
    MOCKER_CPP(&HcclSocket::GetAsyncReqResult).stubs().will(invoke(stub_ZeroCopyMemoryAgent_GetAsyncReqResult));
    MOCKER(aclrtMapMem).stubs().will(returnValue(ACL_SUCCESS));
    MOCKER(aclrtUnmapMem).stubs().will(returnValue(ACL_SUCCESS));
    std::unique_ptr<HcclSocketManager> socketManager;
    socketManager.reset(new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIPs(0x01);
    HcclResult ret = HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    std::vector<RankInfo> rank_vector;
    get_ranks_1server_2dev(rank_vector);
    ZeroCopyMemoryAgent agent(socketManager, 0, 0, localIPs, rank_vector, 0, true, "ZeroCopyMemoryAgentTest");
    EXPECT_EQ(agent.Init(), HCCL_SUCCESS);
    MOCKER(aclrtReserveMemAddress).stubs().will(invoke(aclrtReserveMemAddress_stub));
    MOCKER(aclrtMemImportFromShareableHandle).stubs().will(invoke(aclrtMemImportFromShareableHandle_stub));
    // The local range is device 0's fake buffer; device 3's buffer is passed as the activation handle.
    void *localBase = vir_ptr_map[0].data();
    EXPECT_EQ(agent.SetMemoryRange(localBase, lenth, alignment, flags), HCCL_SUCCESS);
    void *remoteBase = vir_ptr_map[3].data();
    EXPECT_EQ(agent.ActivateCommMemory(localBase, lenth, 0, remoteBase, flags), HCCL_SUCCESS);
    EXPECT_EQ(agent.DeactivateCommMemory(localBase), HCCL_SUCCESS);
    EXPECT_EQ(agent.UnsetMemoryRange(localBase), HCCL_SUCCESS);
    EXPECT_EQ(agent.BarrierClose(), HCCL_SUCCESS);
    EXPECT_EQ(agent.DeInit(), HCCL_SUCCESS);
    agent.mapDevPhyIdconnectedSockets_.clear();
    HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0);
}
// Queues an ACK and a request for peer 1, then runs a single batch-send step. The expectations pin
// hasReq_[0] cleared, hasReq_[1] still set, and a 128-byte reqDataSize_ (presumably one 64-byte
// ACK plus one 64-byte request combined into one send).
TEST_F(ZeroCopyMemoryAgentUt, Ut_RequestBatchSendAsync_When_CombineAckAndReq_ExpectNorm)
{
    MOCKER(GetIsSupSockBatchCloseImmed).stubs().will(invoke(stub_SocketManagerTest_GetIsSupSockBatchCloseImmed));
    u32 interfaceVersion = 1;
    MOCKER(hrtRaGetInterfaceVersion).expects(atMost(2))
        .with(mockcpp::any(), mockcpp::any(), outBoundP(&interfaceVersion))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(HcclSocket::IsSupportAsync).stubs().will(returnValue(true));
    MOCKER(aclrtMapMem).stubs().will(returnValue(ACL_SUCCESS));
    MOCKER(aclrtUnmapMem).stubs().will(returnValue(ACL_SUCCESS));
    std::unique_ptr<HcclSocketManager> socketManager;
    socketManager.reset(new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIPs(0x01);
    HcclResult ret = HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    std::vector<RankInfo> rank_vector;
    get_ranks_1server_2dev(rank_vector);
    ZeroCopyMemoryAgent agent(socketManager, 0, 0, localIPs, rank_vector, 0, true, "ZeroCopyMemoryAgentTest");
    // InitInnerThread is stubbed out so the test can drive the batch-send step by hand.
    MOCKER_CPP(&ZeroCopyMemoryAgent::InitInnerThread).stubs().will(returnValue(HCCL_SUCCESS));
    EXPECT_EQ(agent.Init(), HCCL_SUCCESS);
    MOCKER(aclrtReserveMemAddress).stubs().will(invoke(aclrtReserveMemAddress_stub));
    MOCKER(aclrtMemImportFromShareableHandle).stubs().will(invoke(aclrtMemImportFromShareableHandle_stub));
    MOCKER_CPP(&HcclSocket::SendAsync).stubs().will(returnValue(HCCL_SUCCESS));
    agent.sendMgrs_[1].AddRequest(true, agent.exchangeDataForAck_[1]);
    agent.sendMgrs_[1].AddRequest(false, agent.exchangeDataForSend_);
    agent.RequestBatchSendAsync();
    EXPECT_FALSE(agent.sendMgrs_[1].hasReq_[0].load());
    EXPECT_TRUE(agent.sendMgrs_[1].hasReq_[1].load());
    EXPECT_EQ(agent.sendMgrs_[1].reqDataSize_, 128U);
    EXPECT_EQ(agent.DeInit(), HCCL_SUCCESS);
    agent.mapDevPhyIdconnectedSockets_.clear();
    HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0);
}
// Batch-send failure handling for peer 1, in two phases.
// Phase 1: SendAsync itself fails. The request stays pending (hasReq_[1]) and nothing counts as sent.
// Phase 2: an in-flight send first reports HCCL_E_AGAIN, so its handle is kept. It then completes
// with HCCL_E_TCP_TRANSFER, after which the handle is cleared, the sent sizes are 0 and the request
// is still pending (per the test name, presumably so it is retried).
TEST_F(ZeroCopyMemoryAgentUt, Ut_RequestBatchSendAsync_When_SocketSendFailed_ExpectTryMore)
{
    MOCKER(GetIsSupSockBatchCloseImmed).stubs().will(invoke(stub_SocketManagerTest_GetIsSupSockBatchCloseImmed));
    u32 interfaceVersion = 1;
    MOCKER(hrtRaGetInterfaceVersion).expects(atMost(2))
        .with(mockcpp::any(), mockcpp::any(), outBoundP(&interfaceVersion))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(HcclSocket::IsSupportAsync).stubs().will(returnValue(true));
    MOCKER(aclrtMapMem).stubs().will(returnValue(ACL_SUCCESS));
    MOCKER(aclrtUnmapMem).stubs().will(returnValue(ACL_SUCCESS));
    std::unique_ptr<HcclSocketManager> socketManager;
    socketManager.reset(new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIPs(0x01);
    HcclResult ret = HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    std::vector<RankInfo> rank_vector;
    get_ranks_1server_2dev(rank_vector);
    ZeroCopyMemoryAgent agent(socketManager, 0, 0, localIPs, rank_vector, 0, true, "ZeroCopyMemoryAgentTest");
    // InitInnerThread is stubbed out so the test can drive the send/check steps by hand.
    MOCKER_CPP(&ZeroCopyMemoryAgent::InitInnerThread).stubs().will(returnValue(HCCL_SUCCESS));
    EXPECT_EQ(agent.Init(), HCCL_SUCCESS);
    MOCKER(aclrtReserveMemAddress).stubs().will(invoke(aclrtReserveMemAddress_stub));
    MOCKER(aclrtMemImportFromShareableHandle).stubs().will(invoke(aclrtMemImportFromShareableHandle_stub));

    // Phase 1: the socket send call itself fails.
    MOCKER_CPP(&HcclSocket::SendAsync).stubs().will(returnValue(HCCL_E_NETWORK));
    agent.sendMgrs_[1].AddRequest(false, agent.exchangeDataForSend_);
    agent.sendMgrs_[1].reqDataSize_ = 0;
    agent.sendMgrs_[1].sentSize_ = 0;
    agent.RequestBatchSendAsync();
    EXPECT_FALSE(agent.sendMgrs_[1].hasReq_[0].load());
    EXPECT_TRUE(agent.sendMgrs_[1].hasReq_[1].load());
    EXPECT_EQ(agent.sendMgrs_[1].reqDataSize_, 64U);
    EXPECT_EQ(agent.sendMgrs_[1].sentSize_, 0);

    // Phase 2: the async completion first reports "again", then a transfer error.
    HcclResult sendReqRet = HCCL_E_TCP_TRANSFER;
    MOCKER_CPP(&HcclSocket::GetAsyncReqResult).stubs()
        .with(mockcpp::any(), outBound(sendReqRet))
        .will(returnValue(HCCL_E_AGAIN))
        .then(returnValue(HCCL_SUCCESS));
    agent.sendMgrs_[1].AddRequest(false, agent.exchangeDataForSend_);
    void *handle = reinterpret_cast<void *>(0x01);
    agent.sendMgrs_[1].lastSendSize_ = 0;
    agent.sendMgrs_[1].sentSize_ = 0;
    agent.sendMgrs_[1].reqDataSize_ = 64;
    agent.sendMgrs_[1].lastSendHandle_ = handle;
    agent.CheckBatchSendAsyncResult();
    EXPECT_EQ(agent.sendMgrs_[1].lastSendHandle_, handle);
    agent.CheckBatchSendAsyncResult();
    EXPECT_FALSE(agent.sendMgrs_[1].hasReq_[0].load());
    EXPECT_TRUE(agent.sendMgrs_[1].hasReq_[1].load());
    EXPECT_EQ(agent.sendMgrs_[1].reqDataSize_, 64U);
    EXPECT_EQ(agent.sendMgrs_[1].sentSize_, 0);
    EXPECT_EQ(agent.sendMgrs_[1].lastSendSize_, 0);
    EXPECT_EQ(agent.sendMgrs_[1].lastSendHandle_, nullptr);
    EXPECT_EQ(agent.DeInit(), HCCL_SUCCESS);
    agent.mapDevPhyIdconnectedSockets_.clear();
    HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0);
}
// Batch-receive failure handling for peer 1. RecvAsync hands back handle 0x02. The first completion
// check reports HCCL_E_AGAIN, so the handle is kept. The second reports HCCL_E_TCP_TRANSFER, after
// which the handle is cleared and recvIndex_ is still 0 (per the test name, presumably so the
// receive is retried).
TEST_F(ZeroCopyMemoryAgentUt, Ut_RequestBatchRecvAsync_When_SocketRecvFailed_ExpectTryMore)
{
    MOCKER(GetIsSupSockBatchCloseImmed).stubs().will(invoke(stub_SocketManagerTest_GetIsSupSockBatchCloseImmed));
    u32 interfaceVersion = 1;
    MOCKER(hrtRaGetInterfaceVersion).expects(atMost(2))
        .with(mockcpp::any(), mockcpp::any(), outBoundP(&interfaceVersion))
        .will(returnValue(HCCL_SUCCESS));
    MOCKER(HcclSocket::IsSupportAsync).stubs().will(returnValue(true));
    MOCKER(aclrtMapMem).stubs().will(returnValue(ACL_SUCCESS));
    MOCKER(aclrtUnmapMem).stubs().will(returnValue(ACL_SUCCESS));
    std::unique_ptr<HcclSocketManager> socketManager;
    socketManager.reset(new (std::nothrow) HcclSocketManager(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, 0));
    HcclIpAddress localIPs(0x01);
    HcclResult ret = HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    std::vector<RankInfo> rank_vector;
    get_ranks_1server_2dev(rank_vector);
    ZeroCopyMemoryAgent agent(socketManager, 0, 0, localIPs, rank_vector, 0, true, "ZeroCopyMemoryAgentTest");
    // InitInnerThread is stubbed out so the test can drive the recv/check steps by hand.
    MOCKER_CPP(&ZeroCopyMemoryAgent::InitInnerThread).stubs().will(returnValue(HCCL_SUCCESS));
    EXPECT_EQ(agent.Init(), HCCL_SUCCESS);
    MOCKER(aclrtReserveMemAddress).stubs().will(invoke(aclrtReserveMemAddress_stub));
    MOCKER(aclrtMemImportFromShareableHandle).stubs().will(invoke(aclrtMemImportFromShareableHandle_stub));
    void *handle = reinterpret_cast<void *>(0x02);
    MOCKER_CPP(&HcclSocket::RecvAsync).stubs()
        .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), outBoundP(&handle))
        .will(returnValue(HCCL_SUCCESS));
    agent.recvMgrs_[1].recvIndex_ = 0;
    agent.recvMgrs_[1].lastRecvSize_ = 0;
    agent.RequestBatchRecvAsync();
    EXPECT_EQ(agent.recvMgrs_[1].lastRecvHandle_, handle);
    EXPECT_EQ(agent.recvMgrs_[1].lastRecvSize_, 0);
    HcclResult recvReqRet = HCCL_E_TCP_TRANSFER;
    MOCKER_CPP(&HcclSocket::GetAsyncReqResult).stubs()
        .with(mockcpp::any(), outBound(recvReqRet))
        .will(returnValue(HCCL_E_AGAIN))
        .then(returnValue(HCCL_SUCCESS));
    agent.CheckBatchRecvAsyncResult();
    EXPECT_EQ(agent.recvMgrs_[1].recvIndex_, 0);
    EXPECT_EQ(agent.recvMgrs_[1].lastRecvHandle_, handle);
    agent.CheckBatchRecvAsyncResult();
    EXPECT_EQ(agent.recvMgrs_[1].recvIndex_, 0);
    EXPECT_EQ(agent.recvMgrs_[1].lastRecvHandle_, nullptr);
    EXPECT_EQ(agent.DeInit(), HCCL_SUCCESS);
    agent.mapDevPhyIdconnectedSockets_.clear();
    HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, 0, 0);
}