* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include "acl/acl.h"
#include "hixl/hixl.h"
using namespace hixl;
namespace {
// Size in bytes of the device buffer and of the host buffer each engine allocates and registers (8 MiB).
constexpr size_t kBufferSize = 8 * 1024 * 1024;
// Size in bytes of one transfer descriptor; the buffer is moved as kChunkCount such chunks.
constexpr size_t kChunkSize = 16 * 1024;
// BuildTransferDescs only covers kChunkCount * kChunkSize bytes while Verify checks kBufferSize bytes,
// so a remainder would never be transferred and verification would fail.
static_assert(kChunkSize != 0 && kBufferSize % kChunkSize == 0, "kBufferSize must be a multiple of kChunkSize");
constexpr size_t kChunkCount = kBufferSize / kChunkSize;
// Default device ids used when --device is not given.
constexpr int32_t kDefaultDevA = 0;
constexpr int32_t kDefaultDevB = 2;
// --version value that selects the legacy (roce:device only) configuration flow.
constexpr int32_t kVersionLegacy = 0;
// Engine names in "ip:port" form; the legacy flow derives its listen port from the part after ':'.
constexpr const char *kEngineA = "127.0.0.1:16000";
constexpr const char *kEngineB = "127.0.0.1:16001";
// Byte patterns written into each engine's buffers; they must differ so Verify can tell the directions apart.
constexpr uint8_t kFillA = 0xAA;
constexpr uint8_t kFillB = 0xBB;
// Timeout passed to ConnectAsync/DisconnectAsync (presumably milliseconds — confirm against the hixl docs).
constexpr int32_t kConnTimeout = 5000;
// Maximum number of status polls WaitTransfers performs before giving up.
constexpr int32_t kMaxPollCount = 100000;
// Values accepted by --protocol. An anonymous namespace already gives internal linkage, so `static` is redundant.
const std::vector<std::string> kSupportedProtocols = {"roce:device", "roce:host",     "uboe:device",
                                                      "ubg:device",  "ub_ctp:device", "ub_tp:device",
                                                      "ub_ctp:host", "ub_tp:host"};
// Evaluates an ACL call exactly once and logs any return code other than ACL_ERROR_NONE to stderr.
// NOTE: this only logs; execution continues after a failure. Code that must stop on error has to
// check the aclError itself. The local uses a non-reserved name (identifiers beginning with "__"
// are reserved for the implementation), and the argument is parenthesized.
#define CHECK_ACL(x)                                                                                     \
    do {                                                                                                 \
        const aclError check_acl_ret_ = (x);                                                             \
        if (check_acl_ret_ != ACL_ERROR_NONE) {                                                          \
            std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << check_acl_ret_ << std::endl;     \
        }                                                                                                \
    } while (0)
// State for one side of the test: the Hixl engine, the device it runs on, its device and host
// buffers with their registration handles, and flags recording which setup steps completed, so that
// teardown only undoes what was actually done.
struct EngineCtx {
    Hixl engine;
    int32_t device_id{0};
    const char *name{nullptr};  // engine name, also used as the peer address by the other engine
    bool initialized{false};    // set once engine.Initialize succeeded
    bool connected{false};      // set once both directions reported CONNECTED
    void *dev_buf{nullptr};
    void *host_buf{nullptr};
    MemHandle dev_handle{nullptr};
    MemHandle host_handle{nullptr};
};
// Returns the most recent ACL error message, or "no error" when ACL reports none (nullptr).
const char *GetRecentErrMsg() {
    const char *const recent = aclGetRecentErrMsg();
    if (recent != nullptr) {
        return recent;
    }
    return "no error";
}
// Splits `val` on ',' and appends every piece, empty ones included, to `protocols`.
// Existing entries in `protocols` are kept; an empty `val` appends a single empty string.
void TokenizeProtocols(const std::string &val, std::vector<std::string> &protocols) {
    std::string::size_type begin = 0;
    for (;;) {
        const std::string::size_type comma = val.find(',', begin);
        if (comma == std::string::npos) {
            // Last piece: everything from `begin` to the end of the string.
            protocols.emplace_back(val, begin, std::string::npos);
            return;
        }
        protocols.emplace_back(val, begin, comma - begin);
        begin = comma + 1;
    }
}
// Returns 0 when `proto` is one of kSupportedProtocols. Otherwise prints the offending value and the
// list of supported protocols, then returns -1.
int32_t VerifyProtocolSupport(const std::string &proto) {
    bool known = false;
    for (size_t idx = 0; !known && idx < kSupportedProtocols.size(); ++idx) {
        known = (kSupportedProtocols[idx] == proto);
    }
    if (known) {
        return 0;
    }
    printf("[ERROR] Invalid protocol: %s\n", proto.c_str());
    printf("Supported:");
    for (const std::string &name : kSupportedProtocols) {
        printf(" %s", name.c_str());
    }
    printf("\n");
    return -1;
}
// Returns true and stores the text after `prefix` in `value` when `arg` starts with `prefix`.
bool MatchOption(const std::string &arg, const char *prefix, std::string &value) {
    const size_t prefix_len = std::strlen(prefix);
    if (arg.compare(0, prefix_len, prefix) != 0) {
        return false;
    }
    value = arg.substr(prefix_len);
    return true;
}

// Parses the whole of `text` as a decimal int32_t without letting exceptions escape.
// Rejects empty input, trailing characters (std::stoi alone would accept "1abc" as 1) and out-of-range values.
bool ParseInt32(const std::string &text, int32_t &value) {
    try {
        size_t consumed = 0;
        const int parsed = std::stoi(text, &consumed);
        if (consumed != text.size()) {
            return false;
        }
        value = static_cast<int32_t>(parsed);
        return true;
    } catch (const std::invalid_argument &) {
        return false;
    } catch (const std::out_of_range &) {
        return false;
    }
}

// Parses the command line:
//   --protocol=<p>[,<p>...]  (required, repeatable; each value must be in kSupportedProtocols)
//   --device=<id1>,<id2>     (optional; device ids for engine A and B)
//   --version=0|1            (optional; 0 = legacy flow, which only allows a single roce:device)
// Returns 0 on success and -1 after printing an error. Malformed numbers are reported as errors
// instead of escaping as std::stoi exceptions, which would terminate the program.
int32_t ParseArgs(int32_t argc, char **argv, int32_t &device_a, int32_t &device_b, std::vector<std::string> &protocols,
                  int32_t &version) {
    for (int32_t i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        std::string value;
        if (MatchOption(arg, "--device=", value)) {
            const auto comma_pos = value.find(',');
            const bool ok = comma_pos != std::string::npos && ParseInt32(value.substr(0, comma_pos), device_a) &&
                            ParseInt32(value.substr(comma_pos + 1), device_b);
            if (!ok) {
                printf("[ERROR] Invalid --device format, expected id1,id2\n");
                return -1;
            }
        } else if (MatchOption(arg, "--protocol=", value)) {
            TokenizeProtocols(value, protocols);
        } else if (MatchOption(arg, "--version=", value)) {
            // The usage text documents only 0 (legacy) and 1.
            if (!ParseInt32(value, version) || (version != kVersionLegacy && version != 1)) {
                printf("[ERROR] Invalid --version value: %s, expected 0 or 1\n", value.c_str());
                return -1;
            }
        } else {
            printf("[ERROR] Unknown argument: %s\n", arg.c_str());
            printf("Usage: %s --protocol=<type>[,...] [--device=id1,id2] [--version=0|1]\n", argv[0]);
            return -1;
        }
    }
    if (protocols.empty()) {
        printf("[ERROR] --protocol is required\n");
        return -1;
    }
    for (const auto &proto : protocols) {
        if (VerifyProtocolSupport(proto) != 0) {
            return -1;
        }
    }
    const bool is_legacy = (version == kVersionLegacy);
    const bool single_roce = (protocols.size() == 1 && protocols[0] == "roce:device");
    if (is_legacy && !single_roce) {
        printf("[ERROR] version 0 only supports roce:device\n");
        return -1;
    }
    printf("[INFO] ParseArgs success: device_a=%d, device_b=%d, version=%d\n", device_a, device_b, version);
    for (const auto &proto : protocols) {
        printf("[INFO] protocol: %s\n", proto.c_str());
    }
    return 0;
}
int32_t BuildLegacyConfig(EngineCtx &ctx, const std::vector<std::string> &protocols,
std::map<AscendString, AscendString> &options) {
printf("[INFO] %s using legacy flow (version=0)\n", ctx.name);
std::string eng_name(ctx.name);
auto sep = eng_name.find(':');
uint32_t listen_port = std::stoi(eng_name.substr(sep + 1));
std::string lcomm = "{\"version\": \"1.2\"}";
options[OPTION_LOCAL_COMM_RES] = lcomm.c_str();
std::string res_cfg = "{\"comm_resource_config.listen_port\": " + std::to_string(listen_port) + "}";
options[OPTION_GLOBAL_RESOURCE_CONFIG] = res_cfg.c_str();
if (protocols[0] == "roce:device") {
options[OPTION_BUFFER_POOL] = "0:0";
setenv("HCCL_INTRA_ROCE_ENABLE", "1", 1);
}
return 0;
}
// Fills `options` for the non-legacy flow. Every protocol is emitted as a quoted JSON string inside
// "comm_resource_config.protocol_desc". The protocols are not escaped; ParseArgs has already limited
// them to kSupportedProtocols. Always returns 0.
int32_t BuildV2Config(const std::vector<std::string> &protocols, std::map<AscendString, AscendString> &options) {
    std::string quoted;
    const char *separator = "";
    for (const std::string &proto : protocols) {
        quoted.append(separator).append(1, '"').append(proto).append(1, '"');
        separator = ",";
    }
    const std::string res_config = "{\"comm_resource_config.protocol_desc\": [" + quoted + "]}";
    options[OPTION_GLOBAL_RESOURCE_CONFIG] = res_config.c_str();
    return 0;
}
int32_t InitEngine(EngineCtx &ctx, const std::vector<std::string> &protocols, int32_t version, uint8_t fill_value) {
CHECK_ACL(aclrtSetDevice(ctx.device_id));
std::map<AscendString, AscendString> options;
if (version == kVersionLegacy) {
BuildLegacyConfig(ctx, protocols, options);
} else {
BuildV2Config(protocols, options);
}
auto ret = ctx.engine.Initialize(ctx.name, options);
if (ret != SUCCESS) {
printf("[ERROR] Initialize %s failed, ret=%u, errmsg:%s\n", ctx.name, ret, GetRecentErrMsg());
return -1;
}
ctx.initialized = true;
printf("[INFO] InitEngine %s success\n", ctx.name);
uint8_t *dev_ptr = nullptr;
CHECK_ACL(aclrtMalloc(reinterpret_cast<void **>(&dev_ptr), kBufferSize, ACL_MEM_MALLOC_HUGE_ONLY));
ctx.dev_buf = dev_ptr;
CHECK_ACL(aclrtMallocHost(&ctx.host_buf, kBufferSize));
std::fill(static_cast<uint8_t *>(ctx.host_buf), static_cast<uint8_t *>(ctx.host_buf) + kBufferSize, 0);
MemDesc desc{};
desc.addr = reinterpret_cast<uintptr_t>(ctx.dev_buf);
desc.len = kBufferSize;
ret = ctx.engine.RegisterMem(desc, MEM_DEVICE, ctx.dev_handle);
if (ret != SUCCESS) {
printf("[ERROR] %s RegisterMem device failed, ret=%u, errmsg:%s\n", ctx.name, ret, GetRecentErrMsg());
return -1;
}
desc.addr = reinterpret_cast<uintptr_t>(ctx.host_buf);
ret = ctx.engine.RegisterMem(desc, MEM_HOST, ctx.host_handle);
if (ret != SUCCESS) {
printf("[ERROR] %s RegisterMem host failed, ret=%u, errmsg:%s\n", ctx.name, ret, GetRecentErrMsg());
return -1;
}
std::fill(static_cast<uint8_t *>(ctx.host_buf), static_cast<uint8_t *>(ctx.host_buf) + kBufferSize, fill_value);
CHECK_ACL(aclrtMemcpy(ctx.dev_buf, kBufferSize, ctx.host_buf, kBufferSize, ACL_MEMCPY_HOST_TO_DEVICE));
printf("[INFO] %s InitEngine success, dev:%p, host:%p\n", ctx.name, ctx.dev_buf, ctx.host_buf);
return 0;
}
// Starts asynchronous connections in both directions (A->B and B->A) and polls them every 10 ms until
// both report CONNECTED. Returns -1 when either ConnectAsync call fails, either side reports
// CONNECT_FAILED, or the wait exceeds 2 * kConnTimeout. The `connected` flags are set on both contexts
// only after full success.
int32_t Connect(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    auto ret = ctx_a.engine.ConnectAsync(ctx_b.name, kConnTimeout);
    if (ret != SUCCESS) {
        printf("[ERROR] ConnectAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_a.name, ctx_b.name, ret, GetRecentErrMsg());
        return -1;
    }
    ret = ctx_b.engine.ConnectAsync(ctx_a.name, kConnTimeout);
    if (ret != SUCCESS) {
        printf("[ERROR] ConnectAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_b.name, ctx_a.name, ret, GetRecentErrMsg());
        return -1;
    }
    // The engine is presumably expected to report CONNECT_FAILED once kConnTimeout elapses. If it
    // never does, an unbounded poll would hang forever, so give it a generous margin and then fail.
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(2 * kConnTimeout);
    AsyncConnectStatus status_a = AsyncConnectStatus::NOT_CONNECT;
    AsyncConnectStatus status_b = AsyncConnectStatus::NOT_CONNECT;
    for (;;) {
        ctx_a.engine.GetAsyncConnectStatus(ctx_b.name, status_a);
        ctx_b.engine.GetAsyncConnectStatus(ctx_a.name, status_b);
        if (status_a == AsyncConnectStatus::CONNECT_FAILED || status_b == AsyncConnectStatus::CONNECT_FAILED) {
            printf("[ERROR] Connect failed, status_a=%d, status_b=%d\n", static_cast<int>(status_a),
                   static_cast<int>(status_b));
            return -1;
        }
        if (status_a == AsyncConnectStatus::CONNECTED && status_b == AsyncConnectStatus::CONNECTED) {
            break;
        }
        if (std::chrono::steady_clock::now() >= deadline) {
            printf("[ERROR] Connect timeout, status_a=%d, status_b=%d\n", static_cast<int>(status_a),
                   static_cast<int>(status_b));
            return -1;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    ctx_a.connected = true;
    ctx_b.connected = true;
    printf("[INFO] Connect success\n");
    return 0;
}
void BuildTransferDescs(std::vector<TransferOpDesc> &descs, EngineCtx &ctx_a, EngineCtx &ctx_b) {
descs.reserve(kChunkCount);
for (size_t i = 0; i < kChunkCount; ++i) {
TransferOpDesc desc{};
desc.local_addr = reinterpret_cast<uintptr_t>(ctx_a.dev_buf) + i * kChunkSize;
desc.remote_addr = reinterpret_cast<uintptr_t>(ctx_b.host_buf) + i * kChunkSize;
desc.len = kChunkSize;
descs.push_back(desc);
}
}
// Submits two asynchronous WRITEs:
//   - A -> B using `descs` as given (A's device buffer into B's host buffer);
//   - B -> A after rewriting `descs` in place into the mirrored direction (B's device buffer into A's host buffer).
// On success req_a/req_b hold the request handles and 0 is returned; on failure -1 is returned.
// The mirror keeps each descriptor's own local/remote offsets instead of assuming exactly kChunkCount
// entries. The previous version wrote out of bounds for a shorter vector and left stale A->B entries
// in a longer one.
// NOTE(review): if B's submission fails, A's request is already in flight; the caller only returns -1.
int32_t StartTransfers(EngineCtx &ctx_a, EngineCtx &ctx_b, std::vector<TransferOpDesc> &descs, TransferReq &req_a,
                       TransferReq &req_b) {
    TransferArgs args{};
    req_a = nullptr;
    auto ret = ctx_a.engine.TransferAsync(ctx_b.name, WRITE, descs, args, req_a);
    if (ret != SUCCESS) {
        printf("[ERROR] TransferAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_a.name, ctx_b.name, ret, GetRecentErrMsg());
        return -1;
    }
    const uintptr_t a_dev = reinterpret_cast<uintptr_t>(ctx_a.dev_buf);
    const uintptr_t a_host = reinterpret_cast<uintptr_t>(ctx_a.host_buf);
    const uintptr_t b_dev = reinterpret_cast<uintptr_t>(ctx_b.dev_buf);
    const uintptr_t b_host = reinterpret_cast<uintptr_t>(ctx_b.host_buf);
    for (auto &desc : descs) {
        // Assumes the descriptors were built against A's device buffer / B's host buffer,
        // as BuildTransferDescs does.
        const uintptr_t src_offset = desc.local_addr - a_dev;
        const uintptr_t dst_offset = desc.remote_addr - b_host;
        desc.local_addr = b_dev + src_offset;
        desc.remote_addr = a_host + dst_offset;
    }
    req_b = nullptr;
    ret = ctx_b.engine.TransferAsync(ctx_a.name, WRITE, descs, args, req_b);
    if (ret != SUCCESS) {
        printf("[ERROR] TransferAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_b.name, ctx_a.name, ret, GetRecentErrMsg());
        return -1;
    }
    return 0;
}
// Polls both transfer requests every 100 us until neither is WAITING.
// Gives up with -1 after kMaxPollCount polls. Returns -1 if either request ended in a status other
// than COMPLETED, otherwise 0. A request is no longer queried once it has left WAITING.
int32_t WaitTransfers(EngineCtx &ctx_a, EngineCtx &ctx_b, TransferReq req_a, TransferReq req_b) {
    TransferStatus status_a = TransferStatus::WAITING;
    TransferStatus status_b = TransferStatus::WAITING;
    for (int32_t polls = 1;; ++polls) {
        const bool a_pending = (status_a == TransferStatus::WAITING);
        const bool b_pending = (status_b == TransferStatus::WAITING);
        if (!a_pending && !b_pending) {
            break;
        }
        if (a_pending) {
            ctx_a.engine.GetTransferStatus(req_a, status_a);
        }
        if (b_pending) {
            ctx_b.engine.GetTransferStatus(req_b, status_b);
        }
        if (polls > kMaxPollCount) {
            printf("[ERROR] Transfer poll timeout\n");
            return -1;
        }
        std::this_thread::sleep_for(std::chrono::microseconds(100));
    }
    const auto report_failure = [](const EngineCtx &src, const EngineCtx &dst, TransferStatus st) {
        printf("[ERROR] Transfer %s->%s failed, status=%d\n", src.name, dst.name, static_cast<int>(st));
        return -1;
    };
    if (status_a != TransferStatus::COMPLETED) {
        return report_failure(ctx_a, ctx_b, status_a);
    }
    if (status_b != TransferStatus::COMPLETED) {
        return report_failure(ctx_b, ctx_a, status_b);
    }
    printf("[INFO] Transfer completed\n");
    return 0;
}
// Runs the bidirectional exchange: build A->B descriptors, submit both directions, then wait for both.
// Returns 0 when both transfers completed, -1 otherwise.
int32_t Transfer(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    TransferReq req_a = nullptr;
    TransferReq req_b = nullptr;
    std::vector<TransferOpDesc> ops;
    BuildTransferDescs(ops, ctx_a, ctx_b);
    const int32_t started = StartTransfers(ctx_a, ctx_b, ops, req_a, req_b);
    return (started != 0) ? -1 : WaitTransfers(ctx_a, ctx_b, req_a, req_b);
}
// Returns true when `buf` is non-null and all `len` bytes equal `value`.
bool BufferFilledWith(const void *buf, size_t len, uint8_t value) {
    if (buf == nullptr) {
        return false;
    }
    const auto *bytes = static_cast<const uint8_t *>(buf);
    return std::all_of(bytes, bytes + len, [value](uint8_t b) { return b == value; });
}

// Checks the result of the exchange. A's device buffer (filled with kFillA) was written into B's host
// buffer, and B's device buffer (kFillB) into A's host buffer. Returns 0 when both host buffers hold
// the peer's pattern, -1 otherwise.
// The byte check replaces building two kBufferSize reference vectors for memcmp. Expected values in the
// messages come from the constants, so they cannot go stale.
int32_t Verify(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    if (!BufferFilledWith(ctx_b.host_buf, kBufferSize, kFillA)) {
        printf("[ERROR] Verify %s host failed, expected 0x%02X\n", ctx_b.name, static_cast<unsigned>(kFillA));
        return -1;
    }
    if (!BufferFilledWith(ctx_a.host_buf, kBufferSize, kFillB)) {
        printf("[ERROR] Verify %s host failed, expected 0x%02X\n", ctx_a.name, static_cast<unsigned>(kFillB));
        return -1;
    }
    printf("[INFO] Verify success\n");
    return 0;
}
void DisconnectBoth(EngineCtx &ctx_a, EngineCtx &ctx_b) {
if (ctx_a.connected) {
auto ret = ctx_a.engine.DisconnectAsync(ctx_b.name, kConnTimeout);
if (ret != SUCCESS) {
printf("[ERROR] DisconnectAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_a.name, ctx_b.name, ret,
GetRecentErrMsg());
}
}
if (ctx_b.connected) {
auto ret = ctx_b.engine.DisconnectAsync(ctx_a.name, kConnTimeout);
if (ret != SUCCESS) {
printf("[ERROR] DisconnectAsync %s->%s failed, ret=%u, errmsg:%s\n", ctx_b.name, ctx_a.name, ret,
GetRecentErrMsg());
}
}
if (ctx_a.connected || ctx_b.connected) {
AsyncConnectStatus st_a = AsyncConnectStatus::DISCONNECT_PENDING;
AsyncConnectStatus st_b = AsyncConnectStatus::DISCONNECT_PENDING;
while (st_a != AsyncConnectStatus::NOT_CONNECT || st_b != AsyncConnectStatus::NOT_CONNECT) {
ctx_a.engine.GetAsyncConnectStatus(ctx_b.name, st_a);
ctx_b.engine.GetAsyncConnectStatus(ctx_a.name, st_b);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
printf("[INFO] Disconnect success\n");
}
}
// Releases one engine's memory. Registrations are dropped first, while the buffers still exist, then the
// device buffer (aclrtFree) and host buffer (aclrtFreeHost) are freed. Unset (nullptr) members are
// skipped. ACL failures are only logged via CHECK_ACL.
void CleanupResources(EngineCtx &ctx) {
    const auto deregister = [&ctx](MemHandle &handle) {
        if (handle != nullptr) {
            ctx.engine.DeregisterMem(handle);
        }
    };
    deregister(ctx.dev_handle);
    deregister(ctx.host_handle);

    if (ctx.dev_buf != nullptr) {
        CHECK_ACL(aclrtFree(ctx.dev_buf));
    }
    if (ctx.host_buf != nullptr) {
        CHECK_ACL(aclrtFreeHost(ctx.host_buf));
    }
}
// Calls Finalize on each engine (A first, then B) whose Initialize succeeded.
void ShutdownEngines(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    for (EngineCtx *ctx : {&ctx_a, &ctx_b}) {
        if (!ctx->initialized) {
            continue;
        }
        ctx->engine.Finalize();
    }
}
// Resets A's device, then B's. Runs unconditionally; failures are only logged via CHECK_ACL.
void RestoreDevices(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    const int32_t device_ids[] = {ctx_a.device_id, ctx_b.device_id};
    for (const int32_t id : device_ids) {
        CHECK_ACL(aclrtResetDevice(id));
    }
}
// Tears everything down in the reverse of setup order:
//   1. connections;
//   2. registered memory and buffers (A, then B);
//   3. engines;
//   4. devices.
void Finalize(EngineCtx &ctx_a, EngineCtx &ctx_b) {
    DisconnectBoth(ctx_a, ctx_b);
    for (EngineCtx *ctx : {&ctx_a, &ctx_b}) {
        CleanupResources(*ctx);
    }
    ShutdownEngines(ctx_a, ctx_b);
    RestoreDevices(ctx_a, ctx_b);
}
// Executes the whole scenario: init A (kFillA), init B (kFillB), connect, transfer, verify.
// Stops at the first failing step with -1; otherwise returns Verify's result.
// Teardown is not done here; the caller runs Finalize.
int32_t Run(EngineCtx &ctx_a, EngineCtx &ctx_b, const std::vector<std::string> &protocols, int32_t version) {
    // && short-circuits, so the steps run in order and stop at the first failure.
    const bool ready = InitEngine(ctx_a, protocols, version, kFillA) == 0 &&
                       InitEngine(ctx_b, protocols, version, kFillB) == 0 &&
                       Connect(ctx_a, ctx_b) == 0 &&
                       Transfer(ctx_a, ctx_b) == 0;
    if (!ready) {
        return -1;
    }
    return Verify(ctx_a, ctx_b);
}
}
// Entry point: parse arguments, describe both engines, run the scenario, and always tear down.
// Exits with -1 on argument errors, otherwise with Run's result.
int main(int32_t argc, char **argv) {
    int32_t device_a = kDefaultDevA;
    int32_t device_b = kDefaultDevB;
    int32_t version = 1;  // any value other than kVersionLegacy selects the V2 config flow
    std::vector<std::string> protocols;
    if (ParseArgs(argc, argv, device_a, device_b, protocols, version) != 0) {
        return -1;
    }

    EngineCtx ctx_a;
    ctx_a.device_id = device_a;
    ctx_a.name = kEngineA;
    EngineCtx ctx_b;
    ctx_b.device_id = device_b;
    ctx_b.name = kEngineB;

    // Finalize runs no matter how far Run() got.
    const int32_t status = Run(ctx_a, ctx_b, protocols, version);
    Finalize(ctx_a, ctx_b);
    return status;
}