* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <securec.h>
#include <ifaddrs.h>
#include <sys/socket.h>
#include <netdb.h>
#include <sys/types.h>
#include <stddef.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <hccl/hccl_comm.h>
#include <hccl/hccl_inner.h>
#define private public
#define protected public
#include "hccl_impl.h"
#include "hccl_comm_pub.h"
#undef protected
#undef private
#include "llt_hccl_stub_pub.h"
#include <iostream>
#include <fstream>
#include <nlohmann/json.hpp>
#include "hccl/base.h"
#include "hccl/hccl_ex.h"
#include <hccl/hccl_types.h>
#include "topoinfo_ranktableParser_pub.h"
#include "tsd/tsd_client.h"
#include "dltdt_function.h"
#include <unistd.h>
#include "externalinput_pub.h"
#include "v80_rank_table.h"
#include "externalinput.h"
#include "op_base.h"
#include <functional>
#include <map>
using namespace std;
using namespace hccl;
/**
 * Fixture shared by the op-base multi-thread collective tests in this file.
 *
 * SetUp selects the "ST_TEST" test type, initialises the DlTdt function table,
 * opens a TSD session, assigns a per-invocation shared-memory name, resets the
 * init state and stubs hrtGetHccsPortNum. TearDown closes the TSD session.
 */
class OpbaseMultiThreadTest : public testing::TestWithParam<bool>
{
protected:
    virtual void SetUp()
    {
        // Incremented on every SetUp so each test case gets a distinct shared-memory name.
        static s32 setUpCount = 0;

        ra_set_test_type(0, "ST_TEST");
        DlTdtFunction::GetInstance().DlTdtFunctionInit();
        TsdOpen(1, 2);

        // Name layout: "<counter>_<pretty name of this function>".
        std::string shmName = std::to_string(setUpCount++);
        shmName += "_";
        shmName += __PRETTY_FUNCTION__;
        ra_set_shm_name(shmName.c_str());

        ResetInitState();

        // Stub hrtGetHccsPortNum: return HCCL_SUCCESS and supply -1 through its second argument.
        s32 hccsPortNum = -1;
        MOCKER(hrtGetHccsPortNum)
            .stubs()
            .with(mockcpp::any(), outBound(hccsPortNum))
            .will(returnValue(HCCL_SUCCESS));

        std::cout << "A Test SetUP" << std::endl;
    }

    virtual void TearDown()
    {
        TsdClose(1);
        std::cout << "A Test TearDown" << std::endl;
    }
};
// Plain aggregate bundling a communicator handle, a logical device id and a device count.
// NOTE(review): nothing in the visible part of this file references this struct — confirm it is still needed.
struct ThreadContext {
// Communicator handle.
HcclComm comm;
// Logical device id.
int32_t deviceLogicID;
// Presumably the number of devices/ranks taking part — confirm against users.
uint32_t ndev;
};
// Element count that several Exec* helpers in this file use for their int8 send/receive buffers.
#define HCCL_COM_DATA_SIZE 1024
void ExecAllReduce(int ndev, HcclComm comm, uint32_t deviceLogicID, uint32_t rankId)
{
int ret = HCCL_SUCCESS;
ret = hrtSetDevice(deviceLogicID);
EXPECT_EQ(ret, HCCL_SUCCESS);
rtStream_t stream;
rtError_t rt_ret = RT_ERROR_NONE;
rt_ret = aclrtCreateStream(&stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 count = HCCL_COM_DATA_SIZE;
s8* sendbuf;
sendbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(sendbuf, count * sizeof(s8), 0, count * sizeof(s8));
s8* recvbuf;
recvbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(recvbuf, count * sizeof(s8), 0, count * sizeof(s8));
for (int j = 0; j < count; j++)
{
sendbuf[j] = 2;
}
uint32_t rankSize = 0;
ret = HcclGetRankSize(comm, &rankSize);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankSize, ndev);
uint32_t rankID = 0;
ret = HcclGetRankId(comm, &rankID);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankID, rankId);
ret = HcclAllReduceInner(sendbuf, recvbuf, count, HCCL_DATA_TYPE_INT8, HCCL_REDUCE_SUM, comm, stream);
EXPECT_EQ(ret, HCCL_SUCCESS);
rt_ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 errors = 0;
for (int j = 0; j < count; j++)
{
if (recvbuf[j] != 2)
{
errors ++;
break;
}
}
EXPECT_EQ(errors, 0);
sal_free(sendbuf);
sal_free(recvbuf);
rt_ret = aclrtDestroyStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
ret = hrtResetDevice(deviceLogicID);
EXPECT_EQ(ret, 0);
return;
}
#if 0
// Compiled out via #if 0 (the reason is not recorded in this file).
// Creates a one-device communicator set, runs ExecAllReduce on one thread per
// rank, then resets the devices and destroys the communicators.
TEST_F(OpbaseMultiThreadTest, ut_HcclAllReduce)
{
    // Compile-time bound: the arrays below are ordinary fixed-size arrays, not VLAs.
    constexpr uint32_t ndev = 1;
    int32_t devices[ndev] = {0};
    HcclComm comms[ndev] = {};

    int ret = HCCL_SUCCESS;
    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtSetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
    }

    ret = HcclCommInitAll(ndev, devices, comms);
    EXPECT_EQ(ret, 0);

    std::vector<std::thread> threads;
    threads.reserve(ndev);
    for (uint32_t i = 0; i < ndev; i++) {
        threads.emplace_back(ExecAllReduce, ndev, comms[i], devices[i], i);
    }
    for (auto& worker : threads) {
        worker.join();
    }

    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtResetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
        ret = HcclCommDestroy(comms[i]);
        EXPECT_EQ(ret, 0);
    }
}
#endif
void ExecAllGather(int ndev, HcclComm comm, uint32_t deviceLogicID, uint32_t rankId)
{
int ret = HCCL_SUCCESS;
ret = hrtSetDevice(deviceLogicID);
EXPECT_EQ(ret, HCCL_SUCCESS);
rtStream_t stream;
rtError_t rt_ret = RT_ERROR_NONE;
rt_ret = aclrtCreateStream(&stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 count = 8;
s8* sendbuf;
sendbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(sendbuf, count * sizeof(s8), 0, count * sizeof(s8));
s8* recvbuf;
recvbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(recvbuf, count * sizeof(s8), 0, count * sizeof(s8));
for (int j = 0; j < count; j++)
{
sendbuf[j] = 2;
}
uint32_t rankSize = 0;
ret = HcclGetRankSize(comm, &rankSize);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankSize, ndev);
uint32_t rankID = 0;
ret = HcclGetRankId(comm, &rankID);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankID, rankId);
ret = HcclAllGatherInner(sendbuf, recvbuf, count, HCCL_DATA_TYPE_INT8, comm, stream);
EXPECT_EQ(ret, HCCL_SUCCESS);
rt_ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 errors = 0;
for (int j = 0; j < count; j++)
{
if (recvbuf[j] != 2)
{
errors ++;
break;
}
}
EXPECT_EQ(errors, 0);
sal_free(sendbuf);
sal_free(recvbuf);
rt_ret = aclrtDestroyStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
ret = hrtResetDevice(deviceLogicID);
EXPECT_EQ(ret, 0);
return;
}
#if 0
// Compiled out via #if 0 (the reason is not recorded in this file).
// Creates a one-device communicator set, runs ExecAllGather on one thread per
// rank, then resets the devices and destroys the communicators.
TEST_F(OpbaseMultiThreadTest, ut_HcclAllGather)
{
    // Compile-time bound: the arrays below are ordinary fixed-size arrays, not VLAs.
    constexpr uint32_t ndev = 1;
    int32_t devices[ndev] = {0};
    HcclComm comms[ndev] = {};

    int ret = HCCL_SUCCESS;
    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtSetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
    }

    ret = HcclCommInitAll(ndev, devices, comms);
    EXPECT_EQ(ret, 0);

    std::vector<std::thread> threads;
    threads.reserve(ndev);
    for (uint32_t i = 0; i < ndev; i++) {
        threads.emplace_back(ExecAllGather, ndev, comms[i], devices[i], i);
    }
    for (auto& worker : threads) {
        worker.join();
    }

    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtResetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
        ret = HcclCommDestroy(comms[i]);
        EXPECT_EQ(ret, 0);
    }
}
#endif
void ExecBroadCast(int ndev, HcclComm comm, uint32_t deviceLogicID, uint32_t rankId)
{
int ret = HCCL_SUCCESS;
ret = hrtSetDevice(deviceLogicID);
EXPECT_EQ(ret, HCCL_SUCCESS);
rtStream_t stream;
rtError_t rt_ret = RT_ERROR_NONE;
rt_ret = aclrtCreateStream(&stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 count = HCCL_COM_DATA_SIZE;
s8* sendbuf;
sendbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(sendbuf, count * sizeof(s8), 0, count * sizeof(s8));
for (int j = 0; j < count; j++)
{
sendbuf[j] = 2;
}
uint32_t rankSize = 0;
ret = HcclGetRankSize(comm, &rankSize);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankSize, ndev);
uint32_t rankID = 0;
ret = HcclGetRankId(comm, &rankID);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankID, rankId);
ret = HcclBroadcastInner(sendbuf, count, HCCL_DATA_TYPE_INT8, 0, comm, stream);
EXPECT_EQ(ret, HCCL_SUCCESS);
rt_ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 errors = 0;
for (int j = 0; j < count; j++)
{
if (sendbuf[j] != 2)
{
errors ++;
break;
}
}
EXPECT_EQ(errors, 0);
sal_free(sendbuf);
rt_ret = aclrtDestroyStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
ret = hrtResetDevice(deviceLogicID);
EXPECT_EQ(ret, 0);
return;
}
#if 0
// Compiled out via #if 0 (the reason is not recorded in this file).
// Creates a one-device communicator set, runs ExecBroadCast on one thread per
// rank, then resets the devices and destroys the communicators.
TEST_F(OpbaseMultiThreadTest, ut_HcclBroadCast)
{
    // Compile-time bound: the arrays below are ordinary fixed-size arrays, not VLAs.
    constexpr uint32_t ndev = 1;
    int32_t devices[ndev] = {0};
    HcclComm comms[ndev] = {};

    int ret = HCCL_SUCCESS;
    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtSetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
    }

    ret = HcclCommInitAll(ndev, devices, comms);
    EXPECT_EQ(ret, 0);

    std::vector<std::thread> threads;
    threads.reserve(ndev);
    for (uint32_t i = 0; i < ndev; i++) {
        threads.emplace_back(ExecBroadCast, ndev, comms[i], devices[i], i);
    }
    for (auto& worker : threads) {
        worker.join();
    }

    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtResetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
        ret = HcclCommDestroy(comms[i]);
        EXPECT_EQ(ret, 0);
    }
}
#endif
void ExecReduceScatter(int ndev, HcclComm comm, uint32_t deviceLogicID, uint32_t rankId)
{
int ret = HCCL_SUCCESS;
ret = hrtSetDevice(deviceLogicID);
EXPECT_EQ(ret, HCCL_SUCCESS);
rtStream_t stream;
rtError_t rt_ret = RT_ERROR_NONE;
rt_ret = aclrtCreateStream(&stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 count = HCCL_COM_DATA_SIZE;
s8* sendbuf;
sendbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(sendbuf, count * sizeof(s8), 0, count * sizeof(s8));
s8* recvbuf;
recvbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(recvbuf, count * sizeof(s8), 0, count * sizeof(s8));
for (int j = 0; j < count; j++)
{
sendbuf[j] = 2;
}
uint32_t rankSize = 0;
ret = HcclGetRankSize(comm, &rankSize);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankSize, ndev);
uint32_t rankID = 0;
ret = HcclGetRankId(comm, &rankID);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankID, rankId);
ret = HcclReduceScatterInner(sendbuf, recvbuf, count, HCCL_DATA_TYPE_INT8, HCCL_REDUCE_SUM, comm, stream);
EXPECT_EQ(ret, HCCL_SUCCESS);
rt_ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 errors = 0;
for (int j = 0; j < count; j++)
{
if (recvbuf[j] != 2)
{
errors ++;
break;
}
}
EXPECT_EQ(errors, 0);
sal_free(sendbuf);
sal_free(recvbuf);
rt_ret = aclrtDestroyStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
ret = hrtResetDevice(deviceLogicID);
EXPECT_EQ(ret, 0);
return;
}
#if 0
// Compiled out via #if 0 (the reason is not recorded in this file).
// Creates a one-device communicator set, runs ExecReduceScatter on one thread
// per rank, then resets the devices and destroys the communicators.
TEST_F(OpbaseMultiThreadTest, ut_HcclReduceScatter)
{
    // Compile-time bound: the arrays below are ordinary fixed-size arrays, not VLAs.
    constexpr uint32_t ndev = 1;
    int32_t devices[ndev] = {0};
    HcclComm comms[ndev] = {};

    int ret = HCCL_SUCCESS;
    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtSetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
    }

    ret = HcclCommInitAll(ndev, devices, comms);
    EXPECT_EQ(ret, 0);

    std::vector<std::thread> threads;
    threads.reserve(ndev);
    for (uint32_t i = 0; i < ndev; i++) {
        threads.emplace_back(ExecReduceScatter, ndev, comms[i], devices[i], i);
    }
    for (auto& worker : threads) {
        worker.join();
    }

    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtResetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
        ret = HcclCommDestroy(comms[i]);
        EXPECT_EQ(ret, 0);
    }
}
#endif
void ExecReduce(int ndev, HcclComm comm, uint32_t deviceLogicID, uint32_t rankId)
{
int ret = HCCL_SUCCESS;
ret = hrtSetDevice(deviceLogicID);
EXPECT_EQ(ret, HCCL_SUCCESS);
rtStream_t stream;
rtError_t rt_ret = RT_ERROR_NONE;
rt_ret = aclrtCreateStream(&stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 count = HCCL_COM_DATA_SIZE;
s8* sendbuf;
sendbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(sendbuf, count * sizeof(s8), 0, count * sizeof(s8));
s8* recvbuf;
recvbuf= (s8*)sal_malloc(count * sizeof(s8));
sal_memset(recvbuf, count * sizeof(s8), 0, count * sizeof(s8));
for (int j = 0; j < count; j++)
{
sendbuf[j] = 2;
}
uint32_t rankSize = 0;
ret = HcclGetRankSize(comm, &rankSize);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankSize, ndev);
uint32_t rankID = 0;
ret = HcclGetRankId(comm, &rankID);
EXPECT_EQ(ret, HCCL_SUCCESS);
EXPECT_EQ(rankID, rankId);
ret = HcclReduceInner(sendbuf, recvbuf, count, HCCL_DATA_TYPE_INT8, HCCL_REDUCE_SUM, 0, comm, stream);
EXPECT_EQ(ret, HCCL_SUCCESS);
rt_ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
s32 errors = 0;
if (rankId == 0) {
for (int j = 0; j < count; j++)
{
if (recvbuf[j] != 2)
{
printf("rankId : %d, deviceLogicID: %d, j : %d, val : %d \n", rankId, deviceLogicID, j, recvbuf[j]);
errors ++;
break;
}
}
}
EXPECT_EQ(errors, 0);
sal_free(sendbuf);
sal_free(recvbuf);
rt_ret = aclrtDestroyStream(stream);
EXPECT_EQ(rt_ret, RT_ERROR_NONE);
ret = hrtResetDevice(deviceLogicID);
EXPECT_EQ(ret, 0);
return;
}
#if 0
// Compiled out via #if 0 (the reason is not recorded in this file).
// Creates a one-device communicator set, runs ExecReduce on one thread per
// rank, then resets the devices and destroys the communicators.
TEST_F(OpbaseMultiThreadTest, ut_HcclReduce)
{
    // Compile-time bound: the arrays below are ordinary fixed-size arrays, not VLAs.
    constexpr uint32_t ndev = 1;
    int32_t devices[ndev] = {0};
    HcclComm comms[ndev] = {};

    int ret = HCCL_SUCCESS;
    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtSetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
    }

    ret = HcclCommInitAll(ndev, devices, comms);
    EXPECT_EQ(ret, 0);

    std::vector<std::thread> threads;
    threads.reserve(ndev);
    for (uint32_t i = 0; i < ndev; i++) {
        threads.emplace_back(ExecReduce, ndev, comms[i], devices[i], i);
    }
    for (auto& worker : threads) {
        worker.join();
    }

    for (uint32_t i = 0; i < ndev; i++) {
        ret = hrtResetDevice(devices[i]);
        EXPECT_EQ(ret, 0);
        ret = HcclCommDestroy(comms[i]);
        EXPECT_EQ(ret, 0);
    }
}
#endif