/**
* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
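/*!
* \file distributed_context.h
* \brief Builders for the MC2 communication tiling structures and the declaration of DistributedContext.
*/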
#pragma once
#include <vector>
#include <utility>
#include <string>
#include "hccl_context.h"
#include "tilefwk/platform.h"
#include "interface/tileop/distributed/comm_context.h"
#include "securec.h"
namespace {
/**
 * Abstract interface for objects that build an MC2 communication tiling structure.
 *
 * An implementation owns a concrete configuration struct, fills it in MakeMc2TilingStruct() and hands it
 * out as an untyped pointer through GetMc2CommConfig(). The interface itself holds no state.
 */
class TilingStructBase {
public:
    TilingStructBase() = default;
    virtual ~TilingStructBase() = default;

    /**
     * Fills the implementation's configuration struct for one communication group.
     *
     * @param groupName name of the communication group to record in the configuration.
     * @return status code; the implementations in this file return 0 on success and -1 on failure.
     */
    virtual int32_t MakeMc2TilingStruct(const std::string& groupName) = 0;

    /**
     * @return untyped pointer to the configuration struct owned by the implementation.
     */
    virtual void* GetMc2CommConfig() = 0;
};
class TilingStruct : public TilingStructBase {
public:
TilingStruct() {}
~TilingStruct() {}
int32_t MakeMc2TilingStruct(const std::string& groupName) override
{
(void)memset_s(&Mc2CommConfig_, sizeof(Mc2CommConfig_), 0, sizeof(Mc2CommConfig_));
constexpr uint32_t version = 2;
constexpr uint32_t hcommCnt = 1;
constexpr uint32_t opTypeAllToAll = 6;
const char* algConfig = "AllGather=level0:ring";
Mc2CommConfig_.version = version;
Mc2CommConfig_.hcommCnt = hcommCnt;
Mc2CommConfig_.hcommCfg.skipLocalRankCopy = 0;
Mc2CommConfig_.hcommCfg.skipBufferWindowCopy = 0;
Mc2CommConfig_.hcommCfg.stepSize = 0;
Mc2CommConfig_.hcommCfg.opType = opTypeAllToAll;
if (strcpy_s(Mc2CommConfig_.hcommCfg.groupName, sizeof(Mc2CommConfig_.hcommCfg.groupName), groupName.c_str()) !=
EOK) {
return -1;
}
if (strcpy_s(Mc2CommConfig_.hcommCfg.algConfig, sizeof(Mc2CommConfig_.hcommCfg.algConfig), algConfig) != EOK) {
return -1;
}
return 0;
}
void* GetMc2CommConfig() override { return &Mc2CommConfig_; }
private:
npu::tile_fwk::Mc2CommConfig Mc2CommConfig_;
};
class TilingStructV2 : public TilingStructBase {
public:
TilingStructV2() {}
~TilingStructV2() {}
int32_t MakeMc2TilingStruct(const std::string& groupName) override
{
(void)memset_s(&Mc2CommConfig_, sizeof(Mc2CommConfig_), 0, sizeof(Mc2CommConfig_));
if (npu::tile_fwk::Platform::Instance().GetSoc().GetNPUArch() == npu::tile_fwk::NPUArch::DAV_3510) {
Mc2CommConfig_.inner.version = 100U;
Mc2CommConfig_.inner.commEngine = 3;
(void)memset_s(
Mc2CommConfig_.inner.reserved, sizeof(Mc2CommConfig_.inner.reserved), 0,
sizeof(Mc2CommConfig_.inner.reserved));
} else {
Mc2CommConfig_.inner.version = 1;
}
const char* algConfig = "BatchWrite=level0:fullmesh";
Mc2CommConfig_.init.version = 100U;
Mc2CommConfig_.init.mc2HcommCnt = 1;
Mc2CommConfig_.init.queueNum = 0;
Mc2CommConfig_.init.commBlockNum = 48U;
Mc2CommConfig_.init.devType = 4U;
Mc2CommConfig_.inner.skipLocalRankCopy = 0;
Mc2CommConfig_.inner.skipBufferWindowCopy = 0;
Mc2CommConfig_.inner.stepSize = 0;
Mc2CommConfig_.inner.opType = 18U;
Mc2CommConfig_.inner.version = 1;
Mc2CommConfig_.init.offset[0] = static_cast<uint32_t>(
reinterpret_cast<uint64_t>(&Mc2CommConfig_.inner) - reinterpret_cast<uint64_t>(&Mc2CommConfig_.init));
auto ret = strcpy_s(Mc2CommConfig_.inner.groupName, npu::tile_fwk::GROUP_NAME_SIZE, groupName.c_str());
if (ret != 0) {
return -1;
}
ret = strcpy_s(Mc2CommConfig_.inner.algConfig, npu::tile_fwk::ALG_CONFIG_SIZE, algConfig);
if (ret != 0) {
return -1;
}
return 0;
}
void* GetMc2CommConfig() override { return &Mc2CommConfig_; }
private:
npu::tile_fwk::Mc2CommConfigV2 Mc2CommConfig_;
};
}
namespace npu::tile_fwk::dynamic {
// NOTE(review): presumably the number of window types tracked per communication context - confirm against
// the code that uses this constant.
constexpr int WIN_TYPE_NUM{3};
/**
 * Selector used as the non-type template argument of DistributedContext::AllocCommContext.
 *
 * NOTE(review): the enumerator names suggest a communication topology (ring / mesh) combined with a device
 * generation (A2 / A3 / A5) - confirm against the AllocCommContext specialisations.
 */
enum class ResType {
    RING_A2,
    MESH_A3,
    MESH_A5,
    UNKNOWN
};
/**
 * Collection of static helpers around communication contexts; the class has no data members.
 *
 * Only declarations are visible here; the definitions live elsewhere.
 */
class DistributedContext {
public:
    DistributedContext() {}
    ~DistributedContext() {}

    /**
     * @param groupNames names of the communication groups to look up.
     * @return a vector of 64-bit values.
     *         NOTE(review): presumably one context address per group name - confirm in the definition.
     */
    static std::vector<uint64_t> GetCommContext(const std::vector<std::string>& groupNames);

    /**
     * Same signature as GetCommContext().
     * NOTE(review): the name suggests the contexts are made available on the host side - confirm in the
     * definition.
     *
     * @param groupNames names of the communication groups to look up.
     * @return a vector of 64-bit values.
     */
    static std::vector<uint64_t> GetCommContextToHost(const std::vector<std::string>& groupNames);

    /**
     * Allocation entry point selected at compile time by a ResType value.
     *
     * @param ctxAddr   64-bit context address supplied by the caller.
     * @param groupName name of the communication group.
     * @return a 64-bit value; NOTE(review): presumably the address of the allocated context - confirm.
     */
    template <ResType Kind>
    static uint64_t AllocCommContext(const uint64_t ctxAddr, const std::string& groupName);

private:
    /**
     * Helper taking a host-side TileOp::CommContext and a parameter object of type Param.
     * NOTE(review): presumably copies attributes from hcclParamhost into ctxHost - confirm.
     */
    template <typename Param>
    static void FillCommCtxAttr(TileOp::CommContext* ctxHost, Param* hcclParamhost);

    /**
     * Helper taking an index, a host-side TileOp::CommContext and a parameter object of type Param.
     * NOTE(review): presumably fills the window entry at index i of ctxHost - confirm.
     */
    template <typename Param>
    static void FillCommCtxWinArr(uint32_t i, TileOp::CommContext* ctxHost, Param* hcclParamhost);
};
}