* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file distributed_context.cpp
* \brief
*/
#include "distributed_context.h"
#include <memory>
#include "tilefwk/pypto_fwk_log.h"
#include "adapter/api/hcomm_api.h"
#include "adapter/api/acl_api.h"
#include "interface/utils/common.h"
#include "tilefwk/error_code.h"
#include "machine/runtime/memory_utils/memory_pool.h"
#include "machine/runtime/context/stream_context.h"
#include "machine/device/dynamic/device_utils.h"
constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0;
constexpr uint32_t COMM_MESH = 0b1u;
constexpr uint32_t WINDATA_INDEX = 0;
constexpr uint32_t WINSTATUS_INDEX = 1;
constexpr uint32_t WINDEBUG_INDEX = 2;
constexpr uint64_t WIN_EXP_SIZE = 1024UL * 1024UL;
std::unordered_map<std::string, std::pair<uint64_t, uint64_t>>
g_context;
using namespace npu::tile_fwk;
namespace npu::tile_fwk::dynamic {
uint8_t* AllocHostAddr(size_t size)
{
if (size == 0 || size > 0x500000000) {
DISTRIBUTED_LOGE(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, "AllocHostAddr failed size %zu", size);
return nullptr;
}
auto hostPtr = (uint8_t*)malloc(size);
memset_s(hostPtr, size, 0, size);
return hostPtr;
}
template <typename T>
void DistributedContext::FillCommCtxAttr(TileOp::CommContext* ctxHost, T* hcclParamhost)
{
(void)ctxHost;
(void)hcclParamhost;
}
template <>
void DistributedContext::FillCommCtxAttr<npu::tile_fwk::HcclCombinOpParamA5>(
TileOp::CommContext* ctxHost, npu::tile_fwk::HcclCombinOpParamA5* hcclParamhost)
{
ctxHost->rankId = hcclParamhost->rankId;
ctxHost->rankNum = hcclParamhost->rankNum;
ctxHost->statusIndex = hcclParamhost->rankNum;
ctxHost->debugIndex = hcclParamhost->rankNum * 2;
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, hcclParamhost->winSize > WIN_EXP_SIZE)
<< "winSize: " << hcclParamhost->winSize << " need greater than WIN_EXP_SIZE: " << WIN_EXP_SIZE;
ctxHost->winDataSize = hcclParamhost->winSize - WIN_EXP_SIZE;
ctxHost->winStatusSize = WIN_EXP_SIZE;
ctxHost->winDebugSize = hcclParamhost->winSize;
ctxHost->totalWinNum = hcclParamhost->rankNum * WIN_TYPE_NUM;
}
template <>
void DistributedContext::FillCommCtxAttr<npu::tile_fwk::HcclCombinOpParam>(
TileOp::CommContext* ctxHost, npu::tile_fwk::HcclCombinOpParam* hcclParamhost)
{
ctxHost->rankId = hcclParamhost->rankId;
ctxHost->rankNum = hcclParamhost->rankNum;
ctxHost->statusIndex = hcclParamhost->rankNum;
ctxHost->debugIndex = hcclParamhost->rankNum * 2;
ctxHost->winDataSize = hcclParamhost->winSize;
ctxHost->winStatusSize = hcclParamhost->winExpSize;
ctxHost->winDebugSize = hcclParamhost->winSize;
ctxHost->totalWinNum = hcclParamhost->rankNum * WIN_TYPE_NUM;
}
template <>
void DistributedContext::FillCommCtxAttr<npu::tile_fwk::HcclOpResParamHead>(
TileOp::CommContext* ctxHost, npu::tile_fwk::HcclOpResParamHead* hcclParamhost)
{
ctxHost->rankId = hcclParamhost->localUsrRankId;
ctxHost->rankNum = hcclParamhost->rankSize;
ctxHost->statusIndex = hcclParamhost->rankSize;
ctxHost->debugIndex = hcclParamhost->rankSize * 2;
ctxHost->winDataSize = hcclParamhost->winSize;
ctxHost->winStatusSize = hcclParamhost->winSize;
ctxHost->winDebugSize = hcclParamhost->winExpSize;
ctxHost->totalWinNum = hcclParamhost->rankSize * WIN_TYPE_NUM;
}
template <typename T>
void DistributedContext::FillCommCtxWinArr(uint32_t i, TileOp::CommContext* ctxHost, T* hcclParamhost)
{
(void)i;
(void)ctxHost;
(void)hcclParamhost;
}
template <>
void DistributedContext::FillCommCtxWinArr<npu::tile_fwk::HcclCombinOpParamA5>(
uint32_t i, TileOp::CommContext* ctxHost, npu::tile_fwk::HcclCombinOpParamA5* hcclParamhost)
{
ctxHost->winAddr[i + (WINDATA_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsIn[i] + WIN_EXP_SIZE;
ctxHost->winAddr[i + (WINSTATUS_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsIn[i];
ctxHost->winAddr[i + (WINDEBUG_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsOut[i];
}
template <>
void DistributedContext::FillCommCtxWinArr<npu::tile_fwk::HcclCombinOpParam>(
uint32_t i, TileOp::CommContext* ctxHost, npu::tile_fwk::HcclCombinOpParam* hcclParamhost)
{
ctxHost->winAddr[i + (WINDATA_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsIn[i];
ctxHost->winAddr[i + (WINSTATUS_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsExp[i];
ctxHost->winAddr[i + (WINDEBUG_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsOut[i];
}
template <>
void DistributedContext::FillCommCtxWinArr<npu::tile_fwk::HcclOpResParamHead>(
uint32_t i, TileOp::CommContext* ctxHost, npu::tile_fwk::HcclOpResParamHead* hcclParamhost)
{
ctxHost->winAddr[i + (WINDATA_INDEX * ctxHost->rankNum)] = hcclParamhost->localWindowsIn;
ctxHost->winAddr[i + (WINSTATUS_INDEX * ctxHost->rankNum)] = hcclParamhost->localWindowsExp;
ctxHost->winAddr[i + (WINDEBUG_INDEX * ctxHost->rankNum)] = hcclParamhost->localWindowsOut;
}
template <>
void DistributedContext::FillCommCtxWinArr<npu::tile_fwk::HcclRankRelationResV2>(
uint32_t i, TileOp::CommContext* ctxHost, npu::tile_fwk::HcclRankRelationResV2* hcclParamhost)
{
ctxHost->winAddr[i + (WINDATA_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsIn;
ctxHost->winAddr[i + (WINSTATUS_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsExp;
ctxHost->winAddr[i + (WINDEBUG_INDEX * ctxHost->rankNum)] = hcclParamhost->windowsOut;
}
template <ResType T>
uint64_t DistributedContext::AllocCommContext(
[[maybe_unused]] const uint64_t ctxAddr, [[maybe_unused]] const std::string& groupName)
{
return 0;
}
uint64_t AllocateAndSetupCommContext(
void* paramHost, uint32_t rankNum, const std::string& groupName,
std::function<void(TileOp::CommContext*, void*)> fillAttrFunc,
std::function<void(uint32_t, TileOp::CommContext*, void*)> fillWinArrFunc)
{
size_t commCtxSize = sizeof(TileOp::CommContext) + sizeof(uint64_t) * rankNum * WIN_TYPE_NUM;
TileOp::CommContext* ctxHost = static_cast<TileOp::CommContext*>(static_cast<void*>(AllocHostAddr(commCtxSize)));
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ctxHost != nullptr) << "ctxHost malloc failed";
fillAttrFunc(ctxHost, paramHost);
for (uint32_t i = 0; i < rankNum; i++) {
fillWinArrFunc(i, ctxHost, paramHost);
}
TileOp::CommContext* ctxDevice = nullptr;
DevMemoryPool::Instance().AllocDevAddr(reinterpret_cast<uint8_t**>(&ctxDevice), commCtxSize);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ctxDevice != nullptr) << "ctxDevice malloc failed";
auto ret = AclRtMemcpy(ctxDevice, commCtxSize, ctxHost, commCtxSize, AclRtMemcpyKind::HOST_TO_DEVICE);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
g_context[groupName] = std::make_pair(reinterpret_cast<uint64_t>(ctxDevice), reinterpret_cast<uint64_t>(ctxHost));
return reinterpret_cast<uint64_t>(ctxDevice);
}
template <>
uint64_t DistributedContext::AllocCommContext<ResType::MESH_A5>(
[[maybe_unused]] const uint64_t ctxAddr, [[maybe_unused]] const std::string& groupName)
{
npu::tile_fwk::HcclCombinOpParamA5* hcclParamDevice = (npu::tile_fwk::HcclCombinOpParamA5*)ctxAddr;
npu::tile_fwk::HcclCombinOpParamA5* hcclParamhost =
(npu::tile_fwk::HcclCombinOpParamA5*)AllocHostAddr(sizeof(npu::tile_fwk::HcclCombinOpParamA5));
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, hcclParamhost != nullptr) << "hcclParamhost malloc failed";
size_t offsetRankId = offsetof(npu::tile_fwk::HcclCombinOpParamA5, rankId);
size_t offsetXnAddr = offsetof(npu::tile_fwk::HcclCombinOpParamA5, xnAddr);
auto ret = AclRtMemcpy(
&(hcclParamhost->rankId), offsetXnAddr - offsetRankId, &(hcclParamDevice->rankId), offsetXnAddr - offsetRankId,
AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
return AllocateAndSetupCommContext(
hcclParamhost, hcclParamhost->rankNum, groupName,
[](TileOp::CommContext* ctx, void* param) {
auto* p = static_cast<npu::tile_fwk::HcclCombinOpParamA5*>(param);
FillCommCtxAttr(ctx, p);
},
[](uint32_t i, TileOp::CommContext* ctx, void* param) {
auto* p = static_cast<npu::tile_fwk::HcclCombinOpParamA5*>(param);
FillCommCtxWinArr(i, ctx, p);
});
}
template <>
uint64_t DistributedContext::AllocCommContext<ResType::MESH_A3>(
[[maybe_unused]] const uint64_t ctxAddr, [[maybe_unused]] const std::string& groupName)
{
npu::tile_fwk::HcclCombinOpParam* hcclParamDevice = (npu::tile_fwk::HcclCombinOpParam*)ctxAddr;
npu::tile_fwk::HcclCombinOpParam* hcclParamhost =
(npu::tile_fwk::HcclCombinOpParam*)AllocHostAddr(sizeof(npu::tile_fwk::HcclCombinOpParam));
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, hcclParamhost != nullptr) << "hcclParamhost malloc failed";
size_t offsetRankId = offsetof(npu::tile_fwk::HcclCombinOpParam, rankId);
size_t offsetHcomId = offsetof(npu::tile_fwk::HcclCombinOpParam, hcomId);
size_t offsetWinExpSize = offsetof(npu::tile_fwk::HcclCombinOpParam, winExpSize);
size_t offsetMultiServerFlag = offsetof(npu::tile_fwk::HcclCombinOpParam, multiServerFlag);
auto ret = AclRtMemcpy(
&(hcclParamhost->rankId), offsetHcomId - offsetRankId, &(hcclParamDevice->rankId), offsetHcomId - offsetRankId,
AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
ret = AclRtMemcpy(
&(hcclParamhost->winExpSize), offsetMultiServerFlag - offsetWinExpSize, &(hcclParamDevice->winExpSize),
offsetMultiServerFlag - offsetWinExpSize, AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
return AllocateAndSetupCommContext(
hcclParamhost, hcclParamhost->rankNum, groupName,
[](TileOp::CommContext* ctx, void* param) {
auto* p = static_cast<npu::tile_fwk::HcclCombinOpParam*>(param);
FillCommCtxAttr(ctx, p);
},
[](uint32_t i, TileOp::CommContext* ctx, void* param) {
auto* p = static_cast<npu::tile_fwk::HcclCombinOpParam*>(param);
FillCommCtxWinArr(i, ctx, p);
});
}
template <>
uint64_t DistributedContext::AllocCommContext<ResType::RING_A2>(
[[maybe_unused]] const uint64_t ctxAddr, [[maybe_unused]] const std::string& groupName)
{
npu::tile_fwk::HcclOpResParam* hcclParam = (npu::tile_fwk::HcclOpResParam*)ctxAddr;
HcclOpResParamHead* hcclParamhost = (HcclOpResParamHead*)AllocHostAddr(sizeof(HcclOpResParamHead));
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, hcclParamhost != nullptr) << "hcclParamhost malloc failed";
size_t offsetLocalUsrRankId = offsetof(npu::tile_fwk::HcclOpResParam, localUsrRankId);
size_t offsetRWinStart = offsetof(npu::tile_fwk::HcclOpResParam, rWinStart);
auto ret = AclRtMemcpy(&(hcclParamhost->localUsrRankId), offsetRWinStart - offsetLocalUsrRankId,
&(hcclParam->localUsrRankId), offsetRWinStart - offsetLocalUsrRankId, AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
size_t remoteResSize = hcclParamhost->rankSize * sizeof(npu::tile_fwk::RemoteResPtr);
npu::tile_fwk::RemoteResPtr* remoteResPtr = (npu::tile_fwk::RemoteResPtr*)AllocHostAddr(remoteResSize);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, remoteResPtr != nullptr) << "remoteResPtr malloc failed";
ret = AclRtMemcpy(remoteResPtr, remoteResSize, &(hcclParam->remoteRes), remoteResSize,
AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
size_t commCtxSize = sizeof(TileOp::CommContext) + sizeof(uint64_t) * hcclParamhost->rankSize * WIN_TYPE_NUM;
TileOp::CommContext* ctxHost = (TileOp::CommContext*)AllocHostAddr(commCtxSize);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ctxHost != nullptr) << "ctxHost malloc failed";
FillCommCtxAttr<npu::tile_fwk::HcclOpResParamHead>(ctxHost, hcclParamhost);
for (uint64_t i = 0; i < hcclParamhost->rankSize; i++) {
if (i == hcclParamhost->localUsrRankId) {
FillCommCtxWinArr<npu::tile_fwk::HcclOpResParamHead>(i, ctxHost, hcclParamhost);
continue;
}
uint64_t remoteResDevicePtr;
ret = AclRtMemcpy(&remoteResDevicePtr, sizeof(uint64_t), &(remoteResPtr[i].nextDevicePtr), sizeof(uint64_t),
AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
npu::tile_fwk::HcclRankRelationResV2 remoteParam;
ret = AclRtMemcpy(&remoteParam, sizeof(HcclRankRelationResV2), (void*)remoteResDevicePtr,
sizeof(HcclRankRelationResV2), AclRtMemcpyKind::DEVICE_TO_HOST);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
FillCommCtxWinArr<npu::tile_fwk::HcclRankRelationResV2>(i, ctxHost, &remoteParam);
}
TileOp::CommContext* ctxDevice = nullptr;
DevMemoryPool::Instance().AllocDevAddr((uint8_t**)&ctxDevice, commCtxSize);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ctxDevice != nullptr) << "ctxDevice malloc failed";
ret = AclRtMemcpy(
ctxDevice, sizeof(TileOp::CommContext) + sizeof(uint64_t) * hcclParamhost->rankSize * WIN_TYPE_NUM, ctxHost,
sizeof(TileOp::CommContext) + sizeof(uint64_t) * hcclParamhost->rankSize * WIN_TYPE_NUM,
AclRtMemcpyKind::HOST_TO_DEVICE);
ASSERT(DistributedErrorCode::CONTEXT_CONFIGURE_FAILED, ret == 0) << "AclRtMemcpy failed, error: " << ret;
g_context[groupName].first = (uint64_t)ctxDevice;
g_context[groupName].second = (uint64_t)ctxHost;
return (uint64_t)ctxDevice;
}
std::vector<uint64_t> DistributedContext::GetCommContext([[maybe_unused]] const std::vector<std::string>& groupNames)
{
if (groupNames.size() == 0) {
return {};
}
HCommTopo topoRet;
const char* group = groupNames[0].c_str();
ASSERT(DistributedErrorCode::INVALID_HCCL_TOPO,
HcommGetL0TopoTypeEx(group, &topoRet, COMM_IS_NOT_SET_DEVICE) == HCOMM_SUCCESS) << "Get hcom topo failed";
uint32_t topoType = static_cast<uint32_t>(topoRet);
std::shared_ptr<TilingStructBase> tilingStruct;
if (topoType == COMM_MESH) {
tilingStruct = std::make_shared<TilingStruct>();
} else {
tilingStruct = std::make_shared<TilingStructV2>();
}
std::vector<uint64_t> commContext(groupNames.size(), 0);
ASSERT(DistributedErrorCode::INVALID_GROUP_COUNT, groupNames.size() <= DIST_COMM_GROUP_NUM)
<< "The number of communication groups (" << groupNames.size() << ") exceeds the maximum limit ("
<< DIST_COMM_GROUP_NUM << ").";
for (size_t groupIndex = 0; groupIndex < groupNames.size(); ++groupIndex) {
auto groupName = groupNames[groupIndex];
if (g_context.find(groupName) != g_context.end()) {
commContext[groupIndex] = g_context[groupName].first;
continue;
}
HcommHandle commHandle = nullptr;
ASSERT(DistributedErrorCode::INVALID_GROUP_NAME,
HcommGetCommHandleByGroup(groupName.c_str(), &commHandle) == 0) << "Get hcom handle failed";
tilingStruct->MakeMc2TilingStruct(groupName);
auto ret = HcommAllocComResourceByTiling(
commHandle, GetStreamContext().GetAiCoreStream(), tilingStruct->GetMc2CommConfig(),
reinterpret_cast<void**>(&commContext[groupIndex]));
ASSERT(DistributedErrorCode::HCCL_ALLOC_RESOURCE_FAILED, (ret == 0) && (commContext[groupIndex] != 0UL))
<< "Hccl alloc resource failed";
DISTRIBUTED_LOGI(
"groupIndex=%zu, groupName=%s, commContext=%lu", groupIndex, groupName.c_str(), commContext[groupIndex]);
if (Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510) {
commContext[groupIndex] = AllocCommContext<ResType::MESH_A5>(commContext[groupIndex], groupName);
} else if (topoType == COMM_MESH) {
commContext[groupIndex] = AllocCommContext<ResType::MESH_A3>(commContext[groupIndex], groupName);
} else {
commContext[groupIndex] = AllocCommContext<ResType::RING_A2>(commContext[groupIndex], groupName);
}
}
return commContext;
}
std::vector<uint64_t> DistributedContext::GetCommContextToHost(
[[maybe_unused]] const std::vector<std::string>& groupNames)
{
std::vector<uint64_t> devAddrs = GetCommContext(groupNames);
std::vector<uint64_t> hostContext;
ASSERT(DistributedErrorCode::INVALID_GROUP_COUNT, groupNames.size() <= DIST_COMM_GROUP_NUM)
<< "The number of communication groups (" << groupNames.size() << ") exceeds the maximum limit ("
<< DIST_COMM_GROUP_NUM << ").";
for (size_t groupIndex = 0; groupIndex < groupNames.size(); groupIndex++) {
std::string groupName = groupNames[groupIndex];
hostContext.push_back((uint64_t)g_context[groupName].second);
}
return hostContext;
}
}