aclnnMoeDistributeDispatchSetup
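Product Support
| Product | Supported |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 training series products/Atlas A3 inference series products | × |
| Atlas A2 training series products/Atlas A2 inference series products | × |
| Atlas 200I/500 A2 inference products | × |
| Atlas inference series products | × |
| Atlas training series products | × |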
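Function Description
- Interface description:
  - Optionally quantizes the token data, then performs AllToAllV communication within the EP (Expert Parallelism) domain according to the topK experts selected by each token. This interface only sends the data and the communication status. The operator returns as soon as the communication commands have been issued, without waiting for the communication to complete. Receiving and post-processing the data are handled by aclnnMoeDistributeDispatchTeardown.
  - Note: this interface must be used together with aclnnMoeDistributeDispatchTeardown, aclnnMoeDistributeCombineSetup, and aclnnMoeDistributeCombineTeardown.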
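Function Prototype
Each operator uses a two-stage interface. First call aclnnMoeDistributeDispatchSetupGetWorkspaceSize to obtain the workspace size required for the computation and an executor that contains the operator's computation flow, then call aclnnMoeDistributeDispatchSetup to run the computation.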
aclnnStatus aclnnMoeDistributeDispatchSetupGetWorkspaceSize(
const aclTensor* x,
const aclTensor* expertIds,
const aclTensor* scalesOptional,
const aclTensor* xActiveMaskOptional,
const char* groupEp,
int64_t epWorldSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t expertShardType,
int64_t sharedExpertNum,
int64_t sharedExpertRankNum,
int64_t quantMode,
int64_t globalBs,
int64_t commType,
const char* commAlg,
aclTensor* yOut,
aclTensor* expandIdxOut,
aclTensor* commCmdInfoOut,
uint64_t* workspaceSize,
aclOpExecutor** executor)
aclnnStatus aclnnMoeDistributeDispatchSetup(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
aclnnMoeDistributeDispatchSetupGetWorkspaceSize
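- Parameters:

| Parameter | Input/Output | Description | Usage notes | Data type | Data format | Shape | Non-contiguous tensor |
|---|---|---|---|---|---|---|---|
| x | Input | Token data sent by this card. | Must be a 2D tensor. | FLOAT16, BFLOAT16 | ND | (Bs, H) | √ |
| expertIds | Input | Indices of the topK experts for each token. | Must be a 2D tensor. | INT32 | ND | (Bs, K) | √ |
| scalesOptional | Input | Quantization smoothing parameters for each expert. | Must be a 2D tensor. Pass a null pointer in the non-quantized scenario. With dynamic quantization, either valid data or a null pointer may be passed. | FLOAT32 | ND | (sharedExpertNum + moeExpertNum, H) | √ |
| xActiveMaskOptional | Input | Indicates whether each token participates in communication. | Must be a 1D tensor. Either valid data or a null pointer may be passed. A null pointer means all tokens participate in communication. | BOOL | ND | (Bs, ) | √ |
| groupEp | Input | Name of the EP communication domain (expert-parallel communication domain). | String length must be in [1, 128). | STRING | - | - | - |
| epWorldSize | Input | Size of the EP communication domain. | - | INT64 | - | - | - |
| epRankId | Input | ID of this card within the EP domain. | Value range [0, epWorldSize). The epRankId of each card in the same EP communication domain must be unique. | INT64 | - | - | - |
| moeExpertNum | Input | Number of MoE experts. | Value range (0, 512]. Must satisfy moeExpertNum % (epWorldSize - sharedExpertRankNum) = 0. | INT64 | - | - | - |
| expertShardType | Input | Distribution type of the shared-expert cards. | Only 0 is currently supported, meaning the shared-expert cards are placed before the MoE expert cards. | INT64 | - | - | - |
| sharedExpertNum | Input | Number of shared experts. | Value range [0, 4]. | INT64 | - | - | - |
| sharedExpertRankNum | Input | Number of shared-expert cards. | Value range [0, epWorldSize / 2]. | INT64 | - | - | - |
| quantMode | Input | Quantization mode. | Value range [0, 4]. 0: no quantization; 1: static quantization; 2: per-token dynamic quantization; 3: per-group dynamic quantization; 4: MX quantization. Only 0 and 4 are currently supported. | INT64 | - | - | - |
| globalBs | Input | Global batch size across the EP domain. | When Bs is the same on every rank, globalBs = Bs * epWorldSize or 0. When Bs differs across ranks, globalBs = maxBs * epWorldSize, where maxBs is the largest per-card Bs. | INT64 | - | - | - |
| commType | Input | Selects the communication scheme. | Value range [0, 2]. 0: AICPU-SDMA scheme; 1: CCU scheme; 2: URMA scheme. The current version supports only 2. | INT64 | - | - | - |
| commAlg | Input | Communication-affinity memory layout algorithm. | Reserved field, not supported in the current version. Pass a null pointer or an empty string. | STRING | - | - | - |
| yOut | Output | Communication data to be sent by this card. The input token data is reordered by the algorithm. If quantization is required, the input tokens are quantized first and then reordered. | Must be a 2D tensor. | FLOAT16, BFLOAT16, INT8, HiFP8, FP8E5M2, FP8E4M3 | ND | (BS * (K + sharedExpertNum), tokenMsgSize) | √ |
| expandIdxOut | Output | Number of tokens sent to the same expert. Corresponds to expandIdx in the Combine operators. | Must be a 1D tensor. | INT32 | ND | (BS * K, ) | √ |
| commCmdInfoOut | Output | Communication cmd information. | Must be a 1D tensor. | INT32 | ND | ((BS * (K + sharedExpertNum) + epWorldSize * localExpertNum) * 16, ) | √ |
| workspaceSize | Output | Size of the workspace that must be allocated on the Device. | - | - | - | - | - |
| executor | Output | The op executor, which contains the operator's computation flow. | - | - | - | - | - |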
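The output shapes in the table follow directly from the input parameters, so they can be computed on the host before the output tensors are created. The sketch below is only an illustration: `DispatchSetupShapes` and `ComputeDispatchSetupShapes` are hypothetical helper names that are not part of the aclnn API, and `tokenMsgSize` and `localExpertNum` are taken as inputs as defined in the constraints section.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of aclnn): shapes of the three output tensors.
struct DispatchSetupShapes {
    std::vector<int64_t> yOut;            // (BS * (K + sharedExpertNum), tokenMsgSize)
    std::vector<int64_t> expandIdxOut;    // (BS * K, )
    std::vector<int64_t> commCmdInfoOut;  // ((BS * (K + sharedExpertNum) + epWorldSize * localExpertNum) * 16, )
};

inline DispatchSetupShapes ComputeDispatchSetupShapes(int64_t bs, int64_t k, int64_t sharedExpertNum,
                                                      int64_t epWorldSize, int64_t localExpertNum,
                                                      int64_t tokenMsgSize)
{
    DispatchSetupShapes s;
    // Number of rows this card sends: one per (token, selected or shared expert) pair.
    const int64_t sendRows = bs * (k + sharedExpertNum);
    s.yOut = {sendRows, tokenMsgSize};
    s.expandIdxOut = {bs * k};
    s.commCmdInfoOut = {(sendRows + epWorldSize * localExpertNum) * 16};
    return s;
}
```

With BS = 8, K = 8, sharedExpertNum = 0, epWorldSize = 2, localExpertNum = 16 and tokenMsgSize = 7168 (the values used by the sample code in the call example), this gives yOut (64, 7168), expandIdxOut (64, ) and commCmdInfoOut (1536, ).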
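Ascend 950PR/Ascend 950DT:
- scalesOptional: pass a null pointer in the non-quantized scenario. With dynamic quantization, either valid data or a null pointer may be passed.
- xActiveMaskOptional: either valid data or a null pointer may be passed. A null pointer means all tokens participate in communication.
- groupEp: string length must be in [1, 128).
- epWorldSize: value range [2, 384]. Only 2 and 8 are currently supported.
- epRankId: value range [0, epWorldSize). The epRankId of each card in the same EP communication domain must be unique.
- moeExpertNum: value range (0, 512].
- expertShardType: only 0 is currently supported, meaning the shared-expert cards are placed before the MoE expert cards.
- sharedExpertNum: current value range [0, 4].
- sharedExpertRankNum: value range [0, epWorldSize / 2].
- globalBs: when Bs is the same on every rank, globalBs = Bs * epWorldSize or globalBs = 0. When Bs differs across ranks, globalBs = maxBs * epWorldSize, where maxBs is the largest per-card Bs.
- commType: only 2 is currently supported.
- commAlg: not supported in the current version. Pass a null pointer.

Atlas A3 training series products/Atlas A3 inference series products:
- scalesOptional: pass a null pointer in the non-quantized scenario. With dynamic quantization, either valid data or a null pointer may be passed.
- xActiveMaskOptional: either valid data or a null pointer may be passed. A null pointer means all tokens participate in communication.
- groupEp: string length must be in [1, 128).
- epWorldSize: value range [2, 384]. Only 2 and 8 are currently supported.
- epRankId: value range [0, epWorldSize). The epRankId of each card in the same EP communication domain must be unique.
- moeExpertNum: value range (0, 512].
- expertShardType: only 0 is currently supported, meaning the shared-expert cards are placed before the MoE expert cards.
- sharedExpertNum: current value range [0, 4].
- sharedExpertRankNum: value range [0, epWorldSize / 2].
- globalBs: when Bs is the same on every rank, globalBs = Bs * epWorldSize or globalBs = 0. When Bs differs across ranks, globalBs = maxBs * epWorldSize, where maxBs is the largest per-card Bs.
- commType: only 0 is currently supported.
- commAlg: not supported in the current version. Pass a null pointer.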
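Because the first-stage interface rejects out-of-range arguments, it can be convenient to check the scalar parameters on the host before calling it. The following is a minimal sketch of such a pre-check for the Ascend 950PR/Ascend 950DT ranges listed above. `DispatchSetupParams` and `CheckDispatchSetupParams` are hypothetical names introduced here for illustration, and the check is not a substitute for the validation done by the operator itself.

```cpp
#include <cstdint>
#include <string>

// Hypothetical container for the scalar arguments (not an aclnn type).
struct DispatchSetupParams {
    int64_t epWorldSize;
    int64_t epRankId;
    int64_t moeExpertNum;
    int64_t expertShardType;
    int64_t sharedExpertNum;
    int64_t sharedExpertRankNum;
    int64_t quantMode;
    int64_t commType;
};

// Returns an empty string when every documented range is satisfied,
// otherwise a short message naming the first violated constraint.
inline std::string CheckDispatchSetupParams(const DispatchSetupParams &p)
{
    if (p.epWorldSize != 2 && p.epWorldSize != 8) {
        return "epWorldSize: only 2 and 8 are currently supported";
    }
    if (p.epRankId < 0 || p.epRankId >= p.epWorldSize) {
        return "epRankId must be in [0, epWorldSize)";
    }
    if (p.moeExpertNum <= 0 || p.moeExpertNum > 512) {
        return "moeExpertNum must be in (0, 512]";
    }
    if (p.expertShardType != 0) {
        return "expertShardType: only 0 is currently supported";
    }
    if (p.sharedExpertNum < 0 || p.sharedExpertNum > 4) {
        return "sharedExpertNum must be in [0, 4]";
    }
    if (p.sharedExpertRankNum < 0 || p.sharedExpertRankNum > p.epWorldSize / 2) {
        return "sharedExpertRankNum must be in [0, epWorldSize / 2]";
    }
    // The divisor is positive here because sharedExpertRankNum <= epWorldSize / 2.
    if (p.moeExpertNum % (p.epWorldSize - p.sharedExpertRankNum) != 0) {
        return "moeExpertNum must be divisible by (epWorldSize - sharedExpertRankNum)";
    }
    if (p.quantMode != 0 && p.quantMode != 4) {
        return "quantMode: only 0 and 4 are currently supported";
    }
    if (p.commType != 2) {
        return "commType: only 2 is currently supported on this product";
    }
    return "";
}
```

A caller would log the returned message and skip the aclnn call when it is non-empty.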
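- Return value:
Returns an aclnnStatus status code. For details, see the aclnn return codes.
The first-stage interface validates the input arguments and reports an error in the following cases:

| Return value | Error code | Description |
|---|---|---|
| ACLNN_ERR_PARAM_NULLPTR | 161001 | A required input or output tensor is a null pointer. |
| ACLNN_ERR_PARAM_INVALID | 161002 | The data type of an input or output is not in the supported range. |
| ACLNN_ERR_INNER_TILING_ERROR | 561002 | The shape of an input or output is not in the supported range. A parameter value is not in the supported range. |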
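When logging a failure from the first-stage interface, the numeric code can be mapped back to the rows of the table above. The sketch below assumes nothing beyond the three codes listed there. `DescribeDispatchSetupError` is a hypothetical helper name, and any other code is simply reported as unlisted.

```cpp
#include <cstdint>
#include <string>

// Hypothetical helper: maps the error codes listed in the table to readable text.
inline std::string DescribeDispatchSetupError(int32_t code)
{
    switch (code) {
        case 161001:
            return "ACLNN_ERR_PARAM_NULLPTR: a required input or output tensor is a null pointer";
        case 161002:
            return "ACLNN_ERR_PARAM_INVALID: an input or output data type is not supported";
        case 561002:
            return "ACLNN_ERR_INNER_TILING_ERROR: a shape or parameter value is out of the supported range";
        default:
            return "unlisted aclnnStatus code " + std::to_string(code);
    }
}
```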
aclnnMoeDistributeDispatchSetup
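- Parameters:

| Parameter | Input/Output | Description |
|---|---|---|
| workspace | Input | Address of the workspace memory allocated on the Device. |
| workspaceSize | Input | Size of the workspace allocated on the Device, obtained from the first-stage interface aclnnMoeDistributeDispatchSetupGetWorkspaceSize. |
| executor | Input | The op executor, which contains the operator's computation flow. |
| stream | Input | The Stream on which the task is executed. |

- Return value:
Returns an aclnnStatus status code. For details, see the aclnn return codes.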
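Constraints
- Deterministic computation:
  - aclnnMoeDistributeDispatchSetup is deterministic by default.
- aclnnMoeDistributeDispatchSetup must be used together with aclnnMoeDistributeDispatchTeardown, aclnnMoeDistributeCombineSetup, and aclnnMoeDistributeCombineTeardown.
- The values of groupEp, epWorldSize, moeExpertNum, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBs, commQuantMode, commType, and commAlg used when calling the interface must be identical on all cards. The values of groupEp, epWorldSize, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBs, commQuantMode, commType, and commAlg must also be identical across the layers of the network, and must match the corresponding parameters of aclnnMoeDistributeDispatchTeardown, aclnnMoeDistributeCombineSetup, and aclnnMoeDistributeCombineTeardown.
- Atlas A3 training series products/Atlas A3 inference series products: in this scenario a single card contains two dies, so "this card" in the parameter descriptions refers to a single die.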
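- Shape notation used in the parameter descriptions:
  - A: the maximum number of tokens this card may receive. Value range:
    - For MoE experts: when globalBs is 0, A >= BS * epWorldSize * min(localExpertNum, K). When globalBs is non-zero, A >= globalBs * min(localExpertNum, K).
    - For shared experts: when globalBs is 0, A = BS * epWorldSize * sharedExpertNum / sharedExpertRankNum. When globalBs is non-zero, A = globalBs * sharedExpertNum / sharedExpertRankNum.
  - H: the hidden size, value range [1024, 8192]. Only 4096 and 7168 are currently supported.
  - BS: the batch sequence size, that is, the number of tokens finally output by this card. Value range 0 < BS ≤ 512. Only 8, 16, and 256 are currently supported.
  - K: the number of topK experts selected. Value range 0 < K ≤ 16, and also 0 < K ≤ moeExpertNum. Only 6 and 8 are currently supported.
  - localExpertNum: the number of experts on this card.
    - For a shared-expert card, localExpertNum = 1.
    - For an MoE expert card, localExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum). Only moeExpertNum = 32 is currently supported.
  - tokenMsgSize: the per-token dimension used during data communication.
    - Non-quantized scenario: tokenMsgSize = Align256(H).
    - Quantized scenario: tokenMsgSize = Align512(Align32(H) + 4), where AlignN(x) = ((x + N - 1) / N) * N.
  - Shared experts are not supported in the current version. sharedExpertNum and sharedExpertRankNum currently support only 0.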
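The quantities defined in the shape notation above are plain integer arithmetic, and a short sketch makes the formulas concrete. The function names below (`AlignUp`, `LocalExpertNum`, `MinTokensForMoeCard`, `TokenMsgSize`) are hypothetical helpers introduced for illustration. They cover only the MoE expert card case, since shared experts are not supported in the current version.

```cpp
#include <algorithm>
#include <cstdint>

// AlignN(x) = ((x + N - 1) / N) * N, as defined in the shape notation.
inline int64_t AlignUp(int64_t x, int64_t n)
{
    return ((x + n - 1) / n) * n;
}

// Number of experts hosted on one MoE expert card.
inline int64_t LocalExpertNum(int64_t moeExpertNum, int64_t epWorldSize, int64_t sharedExpertRankNum)
{
    return moeExpertNum / (epWorldSize - sharedExpertRankNum);
}

// Lower bound on A (maximum tokens a MoE expert card may receive).
// globalBs == 0 means "use BS * epWorldSize".
inline int64_t MinTokensForMoeCard(int64_t bs, int64_t epWorldSize, int64_t globalBs,
                                   int64_t localExpertNum, int64_t k)
{
    const int64_t totalTokens = (globalBs == 0) ? bs * epWorldSize : globalBs;
    return totalTokens * std::min(localExpertNum, k);
}

// Per-token message size for the non-quantized and quantized scenarios.
inline int64_t TokenMsgSize(int64_t h, bool quantized)
{
    return quantized ? AlignUp(AlignUp(h, 32) + 4, 512) : AlignUp(h, 256);
}
```

For example, with moeExpertNum = 32, epWorldSize = 2, BS = 8, K = 8 and H = 7168, localExpertNum is 16, the lower bound on A is 128, and tokenMsgSize is 7168 without quantization and 7680 with quantization.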
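- HCCL_BUFFSIZE:
  Before calling this interface, check that the HCCL_BUFFSIZE environment variable is set to a reasonable value. This variable is the amount of memory occupied by a single communication domain, in MB, and defaults to 200 MB when not configured. It must be >= 2 and must satisfy >= 4 * (localExpertNum * maxBs * epWorldSize * Align512(Align32(2 * H) + 44) + (K + sharedExpertNum) * maxBs * Align512(2 * H)), where localExpertNum is the number of experts on an MoE expert card, Align512(x) = ((x + 512 - 1) / 512) * 512, and Align32(x) = ((x + 32 - 1) / 32) * 32.
- Communication domain constraints:
  - Within one model, aclnnMoeDistributeDispatchSetup, aclnnMoeDistributeDispatchTeardown, aclnnMoeDistributeCombineSetup, and aclnnMoeDistributeCombineTeardown must all use the same EP communication domain, and no other operators are allowed in that communication domain.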
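The HCCL_BUFFSIZE lower bound given in the constraints can be evaluated ahead of time. The sketch below computes the right-hand side of the inequality exactly as written. The text does not state the unit of that expression, so the conversion to MB assumes it is a byte count; that assumption, and the helper names `AlignUpTo`, `HcclBuffSizeLowerBound`, and `MinHcclBuffSizeMB`, are ours and should be verified against the official documentation.

```cpp
#include <algorithm>
#include <cstdint>

inline int64_t AlignUpTo(int64_t x, int64_t n)
{
    return ((x + n - 1) / n) * n;
}

// Right-hand side of the documented inequality, term by term.
inline int64_t HcclBuffSizeLowerBound(int64_t localExpertNum, int64_t maxBs, int64_t epWorldSize,
                                      int64_t h, int64_t k, int64_t sharedExpertNum)
{
    const int64_t firstTerm =
        localExpertNum * maxBs * epWorldSize * AlignUpTo(AlignUpTo(2 * h, 32) + 44, 512);
    const int64_t secondTerm = (k + sharedExpertNum) * maxBs * AlignUpTo(2 * h, 512);
    return 4 * (firstTerm + secondTerm);
}

// ASSUMPTION: the bound above is in bytes. Round it up to whole MB and
// apply the documented floor of 2.
inline int64_t MinHcclBuffSizeMB(int64_t lowerBoundBytes)
{
    const int64_t bytesPerMB = 1024 * 1024;
    const int64_t roundedUp = (lowerBoundBytes + bytesPerMB - 1) / bytesPerMB;
    return std::max<int64_t>(2, roundedUp);
}
```

With localExpertNum = 16, maxBs = 8, epWorldSize = 2, H = 7168, K = 8 and sharedExpertNum = 0, the expression evaluates to 18874368, which under the byte assumption is 18 MB and well below the 200 MB default.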
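Call Example
- File preparation:
  - Create the rank_table_m2.json file by following the guidance below, then modify it.
  - Copy the project to both servers and fill in rank_table_m2.json according to each machine's device IPs. The rank_table_m2.json file must be identical on the two machines.
  - Install the CANN package, then compile and run as described in the operator invocation guide.
- About the rank table:
  - Developers can use the ranktable file to configure the NPU resources that take part in collective communication. For details, see "Communication Function Development > Cluster Information Configuration > Configuring Resource Information in the ranktable File" in the Collective Communication User Guide.
  - To query the device IPs of a machine, use either of the following commands, then fill in the json file as described in the collective communication documentation:
        cat /etc/hccn.conf
        for i in $(seq 0 7); do echo "===================> dev$i, NPU$((i+1))"; hccn_tool -i $i -ip -g; done
    Note: in the two-machine 16-card scenario, device_id is 0~7 on both machines; rank_id is 0~7 on one machine and 8~15 on the other. In the single-machine 16-card scenario, device_id and rank_id are both 0~15.
- Environment variable configuration:
        # Three environment variables must be set before running
        ## FIRST_RANK_ID: taking two machines with 16 cards as an example, set it to 0 on one machine and to 8 on the other
        ## For example: export FIRST_RANK_ID=0
        export RANK_TABLE_FILE=/home/path/to/rank_table_m2.json
        export FIRST_RANK_ID=<starting rank_id of the device>
        ## ENV_DEV_NUM: set according to the number of cards; taking two machines with 16 cards as an example, set it to 16 on both machines
        export ENV_DEV_NUM=16
- Machine count setting:
  In the two-machine 16-card scenario, set the parameter MACHINE_NUM to 2, that is, const uint32_t MACHINE_NUM = 2;. No change is needed in the single-machine 16-card scenario.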
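The environment variables and the rank_id layout described above can be read and combined on the host as sketched below. This is an illustration only: `LaunchEnv`, `EnvToInt`, `ReadLaunchEnv`, and `GlobalRankId` are hypothetical names, the fallback values 0 and 16 are assumptions, and the mapping rank_id = FIRST_RANK_ID + device_id is inferred from the two-machine example (ranks 0~7 and 8~15 with device_id 0~7 on each machine).

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical container for the three launch-time settings.
struct LaunchEnv {
    std::string rankTableFile;
    int64_t firstRankId;
    int64_t devNum;
};

// Parses an integer environment variable; returns the fallback when the
// variable is unset, empty, or not a plain decimal integer.
inline int64_t EnvToInt(const char *name, int64_t fallback)
{
    const char *value = std::getenv(name);
    if (value == nullptr || *value == '\0') {
        return fallback;
    }
    char *end = nullptr;
    const long long parsed = std::strtoll(value, &end, 10);
    return (end != value && *end == '\0') ? static_cast<int64_t>(parsed) : fallback;
}

inline LaunchEnv ReadLaunchEnv()
{
    LaunchEnv env;
    const char *table = std::getenv("RANK_TABLE_FILE");
    env.rankTableFile = (table != nullptr) ? table : "";
    env.firstRankId = EnvToInt("FIRST_RANK_ID", 0);  // assumed fallback
    env.devNum = EnvToInt("ENV_DEV_NUM", 16);        // assumed fallback
    return env;
}

// Assumed mapping: global rank_id = FIRST_RANK_ID + local device_id.
inline int64_t GlobalRankId(int64_t firstRankId, int64_t deviceId)
{
    return firstRankId + deviceId;
}
```

On the second machine of a two-machine 16-card setup (FIRST_RANK_ID=8), device_id 7 would map to rank_id 15 under this assumption.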
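Ascend 950PR/Ascend 950DT:
- Environment variable configuration:
        # The RANK_TABLE_FILE environment variable must be set before running
        export RANK_TABLE_FILE=/home/path/to/rank_table_m2.json
The sample code below is for reference only. For the compile and run procedure, see the compile-and-run sample guide.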
#include <thread>
#include <iostream>
#include <string>
#include <vector>
#include <random>
#include <fstream>
#include <getopt.h>
#include <cstring>
#include <sys/wait.h>
#include <bits/stdc++.h>
#include <unistd.h>
#include "acl/acl.h"
#include "hccl/hccl.h"
#include "aclnnop/aclnn_moe_distribute_dispatch_setup.h"
#include "aclnnop/aclnn_moe_distribute_dispatch_teardown.h"

int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    int64_t shapesize = 1;
    for (auto i : shape) {
        shapesize *= i;
    }
    return shapesize;
}

static int64_t GetCeil(int64_t a, int64_t b)
{
    return (a + b - 1) / b;
}

#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (!(cond)) {               \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)         \
    do {                                \
        printf(message, ##__VA_ARGS__); \
    } while (0)

struct Args {
    uint32_t rankId;
    uint32_t epRankId;
    uint32_t tpRankId;
    HcclComm hcclEpComm;
    HcclComm hcclTpComm;
    aclrtStream dispatchsetupstream;
    aclrtStream dispatchteardownstream;
    aclrtContext context;
};

// Global parameters
int64_t g_bs = 8;
int64_t g_h = 7168;
int64_t g_k = 8;
int64_t g_expertSharedType = 0;
int64_t g_sharedExpertNum = 0;
int64_t g_sharedExpertRankNum = 0;
int64_t g_moeExpertNum = 32;
int64_t g_quantMode = 0;
int64_t g_expertTokenNumsType = 1;
int64_t g_outDtype = 0;
int64_t g_commQuantMode = 0;
// Parameters used in quantization scenarios.
// In the static quantization scenario, whether all experts share scales:
// 0: experts do not share scales
// 1: all experts share scales, with shape (h,)
// 2: all experts share scales, with shape (1,)
int64_t g_isStaticSharedScales = 0;
int64_t g_staticSharedScalesOne = 1;
int64_t g_staticSharedScalesTwo = 2;
int64_t g_groupListType = 1;
int64_t g_rankId = 0;
int64_t g_epWorldSize = 2;
int64_t g_tpWorldSize = 1;
int64_t g_hasSmoothScale = 0;
int64_t FIRST_RANK_ID = 0;
int64_t commType = 2;
int64_t globalBS = g_bs * g_epWorldSize;
int64_t localExpertNum = g_moeExpertNum / (g_epWorldSize - g_sharedExpertRankNum);
int64_t localToken = globalBS * (localExpertNum < g_k ? localExpertNum : g_k);
uint64_t tokenMsgSize = 0;
uint64_t expandIdxOutSize = 16;
uint64_t assistInfoForCombineOutSize = localToken * 128;
bool hasSmoothScale = g_hasSmoothScale != 0;
int64_t g_timeOut = 100000000;
int64_t g_hcclBufferSize = 12000;
std::string g_xDtype = "fp16";
std::string g_caseName = "./";
std::string g_ranktablePath = "./";
std::string g_scaleType = "fp32";
std::string g_expandxDtype = "fp16";
std::string g_yOutDtype = "fp16";

/* Build the device-side input/output variables for the current scenario */
// Declare the device-side input/output variables
void *xDeviceAddr = nullptr;
void *expertIdsDeviceAddr = nullptr;
void *scalesDeviceAddr = nullptr;
void *expertScalesDeviceAddr = nullptr;
void *expandXDeviceAddr = nullptr;
void *dynamicScalesDeviceAddr = nullptr;
void *expandIdxDeviceAddr = nullptr;
void *expertTokenNumsDeviceAddr = nullptr;
void *epRecvCountsDeviceAddr = nullptr;
void *tpRecvCountsDeviceAddr = nullptr;
void *expandScalesDeviceAddr = nullptr;
void *yOutDeviceAddr = nullptr;
void *assistInfoForCombineOutDeviceAddr = nullptr;
void *quantExpandXOutDeviceAddr = nullptr;
void *xOutDeviceAddr = nullptr;
void *commCmdInfoOutDeviceAddr = nullptr;
void *commCmdInfoOutShapeforcombineDeviceAddr = nullptr;

aclTensor *x = nullptr;
aclTensor *expertIds = nullptr;
aclTensor *scalesOptional = nullptr;
aclTensor *xActiveMaskOptional = nullptr;
aclTensor *yOut = nullptr;
aclTensor *xOut = nullptr;
aclTensor *sharedExpertXOptional = nullptr;
aclTensor *scales = nullptr;
aclTensor *expertScales = nullptr;
aclTensor *expandX = nullptr;
aclTensor *dynamicScales = nullptr;
aclTensor *expandIdx = nullptr;
aclTensor *expertTokenNums = nullptr;
aclTensor *epRecvCounts = nullptr;
aclTensor *tpRecvCounts = nullptr;
aclTensor *expandScales = nullptr;
aclTensor *assistInfoForCombineOut = nullptr;
aclTensor *quantExpandXOut = nullptr;
aclTensor *commCmdInfoOut = nullptr;
aclTensor *commCmdInfoOutforCombine = nullptr;

// Define the shape of each variable for the current scenario
std::vector<int64_t> xShape{g_bs, g_h};
std::vector<int64_t> expertIdsShape{g_bs, g_k};
std::vector<int64_t> scalesShape{(g_sharedExpertRankNum > 0) ? 1 + g_moeExpertNum : g_moeExpertNum, g_h};
std::vector<int64_t> scaleshif8Shape{1};
std::vector<int64_t> scalesCommonShape_h{g_h};
std::vector<int64_t> scalesCommonShape_1{1};
std::vector<int64_t> expertScalesShape{g_bs, g_k};
std::vector<int64_t> expandXShape{g_tpWorldSize * localToken, g_h};
std::vector<int64_t> scalesmodel2Shape{g_bs, g_h};
std::vector<int64_t> yOutShape{g_bs * (g_k + g_sharedExpertNum), 7168};
std::vector<int64_t> dynamicScalesShape_H{g_tpWorldSize * localToken, g_h};
std::vector<int64_t> dynamicScalesShape_ceil32{g_tpWorldSize * localToken, GetCeil(g_h, 32)};
std::vector<int64_t> dynamicScalesShape_ceil128{g_tpWorldSize * localToken, GetCeil(g_h, 128)};
std::vector<int64_t> dynamicScalesShape{g_tpWorldSize * localToken};
std::vector<int64_t> dynamicModel2ScalesShape{localToken * g_h};
std::vector<int64_t> expandIdxShape{g_bs * g_k};
std::vector<int64_t> expertTokenNumsShape{localExpertNum};
std::vector<int64_t> epRecvCountsShape{g_tpWorldSize * localExpertNum * g_epWorldSize};
std::vector<int64_t> tpRecvCountsShape{g_tpWorldSize * localExpertNum};
std::vector<int64_t> expandScalesShape{localToken};
// Explicit cast avoids a narrowing conversion from uint64_t inside the braces
std::vector<int64_t> assistInfoForCombineOutShape{static_cast<int64_t>(assistInfoForCombineOutSize)};
std::vector<int64_t> quantExpandXOutShape{g_tpWorldSize * localToken, 10752};
std::vector<int64_t> commCmdInfoOutShape{(g_bs * (g_k + g_sharedExpertNum) + g_epWorldSize * localExpertNum) * 16};
std::vector<int64_t> commCmdInfoOutShapeforcombine{1056};
std::vector<int64_t> xOutShape{g_bs, g_h};

long long xShapeSize = GetShapeSize(xShape);
int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
int64_t scalesShapeSize = GetShapeSize(scalesShape);
int64_t scalesmodel2ShapeSize = GetShapeSize(scalesmodel2Shape);
int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
int64_t expandXShapeSize = GetShapeSize(expandXShape);
int64_t yOutShapeSize = GetShapeSize(yOutShape);
int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
int64_t dynamicModel2ScalesShapeSize = GetShapeSize(dynamicModel2ScalesShape);
int64_t dynamicScalesShapeSizeH = GetShapeSize(dynamicScalesShape_H);
int64_t dynamicScalesShapeSizeCeil32 = GetShapeSize(dynamicScalesShape_ceil32);
int64_t dynamicScalesShapeSizeCeil128 = GetShapeSize(dynamicScalesShape_ceil128);
int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape);
int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
int64_t assistInfoForCombineOutShapeSize = GetShapeSize(assistInfoForCombineOutShape);
int64_t commCmdInfoOutShapeSize = GetShapeSize(commCmdInfoOutShape);
int64_t quantExpandXOutShapeSize = GetShapeSize(quantExpandXOutShape);
int64_t xOutShapeSize = GetShapeSize(xOutShape);
int64_t commCmdInfoOutShapeSizeforcombine = GetShapeSize(commCmdInfoOutShapeforcombine);

// Build the host-side variables
std::vector<int16_t> xHostMode1Data(xShapeSize, 1);
std::vector<int8_t> xHostMode2Data(xShapeSize, 1);
std::vector<int32_t> expertIdsHostData(expertIdsShapeSize, 0);
std::vector<float> scalesModel2FloatHostData(scalesmodel2ShapeSize, 0);
std::vector<int8_t> scalesMode2Int8HostData(scalesmodel2ShapeSize, 0);
std::vector<float> scalesHostData(scalesShapeSize, 0);
std::vector<int8_t> scalesMxfp8HostData(scalesShapeSize, 0);
std::vector<float> expertScalesHostData(expertScalesShapeSize, 0);
std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
std::vector<float> dynamicModel2ScalesHostDataFloat(dynamicModel2ScalesShapeSize, 0);
std::vector<int8_t> dynamicModel2ScalesHostDataInt8(dynamicModel2ScalesShapeSize, 0);
std::vector<float> dynamicScalesHostDataH(dynamicScalesShapeSizeH, 0);
std::vector<int8_t> dynamicScalesHostDataCeil32(dynamicScalesShapeSizeCeil32, 0);
std::vector<int8_t>
dynamicScalesHostDataCeil128(dynamicScalesShapeSizeCeil128, 0);
std::vector<int8_t> dynamicScalesMxfp8HostData(dynamicScalesShapeSize, 0);
std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0);  // int32_t matches the ACL_INT32 tensor dtype
std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
std::vector<int8_t> expandXQuantHostData(expandXShapeSize, 0);
std::vector<float> yOutHostData(yOutShapeSize, 0);
std::vector<int32_t> AssistInfoForCombineOutHostData(assistInfoForCombineOutShapeSize, 0);
std::vector<int32_t> commCmdInfoOutHostData(commCmdInfoOutShapeSize, 0);
std::vector<int8_t> quantExpandXOutHostData(quantExpandXOutShapeSize, 0);
std::vector<float> xOutHostData(xOutShapeSize, 0);
std::vector<int32_t> commCmdInfoOutforCombineHostData(commCmdInfoOutShapeSizeforcombine, 0);

// 0: non-quantized dispatch scenario
// 1: static quantization
// 2: per-token dynamic quantization
// 3: per-group dynamic quantization
// 4: MX quantization
int64_t g_noQuant = 0;
int64_t g_staticQuant = 1;
int64_t g_pertokenQuant = 2;
int64_t g_pergroupQuant = 3;
int64_t g_mxQuant = 4;

/* Variables required to execute the operator */
uint64_t workspaceSize = 0;
aclOpExecutor *executor = nullptr;
void *workspaceAddr = nullptr;

std::map<std::string, aclDataType> x_dtypeIn = {
    {"bf16", ACL_BF16},
    {"fp16", ACL_FLOAT16},
    {"fp8_e5m2", ACL_FLOAT8_E5M2},
    {"fp8_e4m3fn", ACL_FLOAT8_E4M3FN},
    {"hif8", ACL_HIFLOAT8},
};

std::map<std::string, aclDataType> expandx_dtypeOut = {
    {"bf16", ACL_BF16},
    {"fp16", ACL_FLOAT16},
    {"int8", ACL_INT8},
    {"fp8_e5m2", ACL_FLOAT8_E5M2},
    {"fp8_e4m3fn", ACL_FLOAT8_E4M3FN},
    {"hif8", ACL_HIFLOAT8},
};

std::map<std::string, aclDataType> scaleIn = {
    {"fp32", ACL_FLOAT},
    {"fp8_e8m0", ACL_FLOAT8_E8M0},
};

std::vector<int32_t> generate_random_vector(int64_t g_bs, int64_t g_k, int seed)
{
    // Generate a vector for expert_id whose values are randomly 0 or 1
    // Total size
    int total_size = g_bs * g_k;
    // Random engine and distribution
    std::mt19937 gen(seed);                     // Mersenne Twister engine
    std::uniform_int_distribution<> dis(0, 1);  // produces 0 or 1
    // Create and fill the vector
    std::vector<int32_t> result;
    result.reserve(total_size);  // pre-allocate memory for performance
    for (int i = 0; i < total_size; ++i) {
        result.push_back(dis(gen));
    }
    return result;
}

template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor, uint32_t &rankId, const std::string &name)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}

int CreateAclTensorX(Args &args)
{
    int ret = 0;
    if (g_quantMode == g_noQuant && (g_xDtype == "fp8_e5m2" || g_xDtype == "fp8_e4m3fn" || g_xDtype == "hif8")) {
        ret = CreateAclTensor(xHostMode2Data, xShape, &xDeviceAddr, x_dtypeIn[g_xDtype], &x, args.rankId, "x");
    } else {
        ret = CreateAclTensor(xHostMode1Data, xShape, &xDeviceAddr, x_dtypeIn[g_xDtype], &x, args.rankId, "x");
    }
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpertIds(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32,
                          &expertIds, args.rankId, "expertIds");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorScales(Args &args)
{
    int ret = 0;
    if (g_quantMode == g_noQuant &&
        (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        scales = nullptr;
    } else if ((g_quantMode == g_noQuant) && (x_dtypeIn[g_xDtype] == ACL_HIFLOAT8)) {
        ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales,
                              args.rankId, "scales");
    } else if ((g_quantMode == g_noQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_FLOAT8_E4M3FN || x_dtypeIn[g_xDtype] == ACL_FLOAT8_E5M2)) {
        if (scaleIn[g_scaleType] == ACL_FLOAT) {
            ret = CreateAclTensor(scalesModel2FloatHostData, scalesmodel2Shape, &scalesDeviceAddr,
                                  aclDataType::ACL_FLOAT, &scales, args.rankId, "scales");
        } else if (scaleIn[g_scaleType] == ACL_FLOAT8_E8M0) {
            ret = CreateAclTensor(scalesMode2Int8HostData, scalesmodel2Shape, &scalesDeviceAddr,
                                  aclDataType::ACL_FLOAT8_E8M0, &scales, args.rankId, "scales");
        }
    } else if ((g_quantMode == g_staticQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        if (g_isStaticSharedScales == g_staticSharedScalesOne) {
            ret = CreateAclTensor(scalesHostData, scalesCommonShape_h, &scalesDeviceAddr, aclDataType::ACL_FLOAT,
                                  &scales, args.rankId, "scales");
        } else if (g_isStaticSharedScales == g_staticSharedScalesTwo) {
            ret = CreateAclTensor(scalesHostData, scalesCommonShape_1, &scalesDeviceAddr, aclDataType::ACL_FLOAT,
                                  &scales, args.rankId, "scales");
        } else {
            ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales,
                                  args.rankId, "scales");
        }
    } else if ((g_quantMode == g_pertokenQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales,
                              args.rankId, "scales");
    } else if ((g_quantMode == g_pergroupQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales,
                              args.rankId, "scales");
    } else if ((g_quantMode == g_mxQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        scales = nullptr;
    }
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpertScales(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT,
                          &expertScales, args.rankId, "expertScales");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpandX(Args &args)
{
    int ret = 0;
    if (g_quantMode > 0) {
        ret = CreateAclTensor(expandXQuantHostData, expandXShape, &expandXDeviceAddr,
                              expandx_dtypeOut[g_expandxDtype], &expandX, args.rankId, "expandX");
    } else if ((g_quantMode == g_noQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_HIFLOAT8 || x_dtypeIn[g_xDtype] == ACL_FLOAT8_E4M3FN ||
                x_dtypeIn[g_xDtype] == ACL_FLOAT8_E5M2)) {
        ret = CreateAclTensor(expandXQuantHostData, expandXShape, &expandXDeviceAddr,
                              expandx_dtypeOut[g_expandxDtype], &expandX, args.rankId, "expandX");
    } else {
        ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, expandx_dtypeOut[g_expandxDtype],
                              &expandX, args.rankId, "expandX");
    }
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorDynamicScales(Args &args)
{
    int ret = 0;
    if (g_quantMode == g_noQuant && (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr,
                              aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
    } else if ((g_quantMode == g_noQuant) && (x_dtypeIn[g_xDtype] == ACL_HIFLOAT8)) {
        ret = CreateAclTensor(dynamicModel2ScalesHostDataFloat, dynamicModel2ScalesShape, &dynamicScalesDeviceAddr,
                              aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
    } else if ((g_quantMode == g_noQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_FLOAT8_E4M3FN || x_dtypeIn[g_xDtype] == ACL_FLOAT8_E5M2)) {
        if (scaleIn[g_scaleType] == ACL_FLOAT) {
            ret =
                CreateAclTensor(dynamicModel2ScalesHostDataFloat, dynamicModel2ScalesShape, &dynamicScalesDeviceAddr,
                                aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
        } else if (scaleIn[g_scaleType] == ACL_FLOAT8_E8M0) {
            ret = CreateAclTensor(dynamicModel2ScalesHostDataInt8, dynamicModel2ScalesShape,
                                  &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT8_E8M0, &dynamicScales,
                                  args.rankId, "dynamicScales");
        } else {
            ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr,
                                  aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
        }
    } else if ((g_quantMode == g_staticQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        if (expandx_dtypeOut[g_expandxDtype] == ACL_INT8) {
            ret = CreateAclTensor(dynamicScalesHostDataH, dynamicScalesShape_H, &dynamicScalesDeviceAddr,
                                  aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
        } else if (expandx_dtypeOut[g_expandxDtype] == ACL_HIFLOAT8) {
            ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr,
                                  aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
        }
    } else if ((g_quantMode == g_pertokenQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr,
                              aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
    } else if ((g_quantMode == g_pergroupQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(dynamicScalesHostDataCeil128, dynamicScalesShape_ceil128, &dynamicScalesDeviceAddr,
                              aclDataType::ACL_FLOAT, &dynamicScales, args.rankId, "dynamicScales");
    } else if ((g_quantMode == g_mxQuant) &&
               (x_dtypeIn[g_xDtype] == ACL_BF16 || x_dtypeIn[g_xDtype] == ACL_FLOAT16)) {
        ret = CreateAclTensor(dynamicScalesHostDataCeil32, dynamicScalesShape_ceil32, &dynamicScalesDeviceAddr,
                              aclDataType::ACL_FLOAT8_E8M0, &dynamicScales, args.rankId, "dynamicScales");
    }
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpandIdx(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32,
                          &expandIdx, args.rankId, "expandIdx");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpertTokenNums(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr,
                          aclDataType::ACL_INT64, &expertTokenNums, args.rankId, "expertTokenNums");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorEpRecvCounts(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32,
                          &epRecvCounts, args.rankId, "epRecvCounts");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorTpRecvCounts(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32,
                          &tpRecvCounts, args.rankId, "tpRecvCounts");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorExpandScales(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT,
                          &expandScales, args.rankId, "expandScales");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorYOut(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(yOutHostData, yOutShape, &yOutDeviceAddr, aclDataType::ACL_FLOAT16, &yOut, args.rankId,
                          "yOut");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorAssistInfoForCombineOut(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(AssistInfoForCombineOutHostData, assistInfoForCombineOutShape,
                          &assistInfoForCombineOutDeviceAddr, aclDataType::ACL_INT32, &assistInfoForCombineOut,
                          args.rankId, "assistInfoForCombineOut");
    CHECK_RET(ret == ACL_SUCCESS,
              return ret);
    return ret;
}

int CreateAclTensorcommCmdInfoOut(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(commCmdInfoOutHostData, commCmdInfoOutShape, &commCmdInfoOutDeviceAddr,
                          aclDataType::ACL_INT32, &commCmdInfoOut, args.rankId, "commCmdInfoOut");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorquantExpandXOut(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(quantExpandXOutHostData, quantExpandXOutShape, &quantExpandXOutDeviceAddr,
                          aclDataType::ACL_INT8, &quantExpandXOut, args.rankId, "quantExpandXOut");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclTensorxOut(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_FLOAT16, &xOut, args.rankId,
                          "xOut");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int CreateAclcommCmdInfoOutforCombine(Args &args)
{
    int ret = 0;
    ret = CreateAclTensor(commCmdInfoOutforCombineHostData, commCmdInfoOutShapeforcombine,
                          &commCmdInfoOutShapeforcombineDeviceAddr, aclDataType::ACL_INT32,
                          &commCmdInfoOutforCombine, args.rankId, "commCmdInfoOutforCombine");
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return ret;
}

int ProcessDispatchSetup(char *hcomEpName, char *hcomTpName, Args &args)
{
    int ret = 0;
    /****************************** Call dispatch_setup first ******************************/
    // Call the first-stage interface of the dispatch_setup operator
    ret = aclnnMoeDistributeDispatchSetupGetWorkspaceSize(
        x, expertIds, (hasSmoothScale ? scales : nullptr), nullptr, hcomEpName, g_epWorldSize, args.epRankId,
        g_moeExpertNum, g_expertSharedType, g_sharedExpertNum, g_sharedExpertRankNum, g_quantMode, globalBS,
        commType, nullptr, yOut, expandIdx, commCmdInfoOut, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchSetupGetWorkspaceSize failed. ret = %d \n", ret);
              return ret);
    // Allocate device memory of the workspaceSize computed by the first-stage interface
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] DispatchSetup aclrtMalloc failed. ret = %d\n", ret);
                  return ret);
    }
    // Call the second-stage interface of the dispatch_setup operator
    ret = aclnnMoeDistributeDispatchSetup(workspaceAddr, workspaceSize, executor, args.dispatchsetupstream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchSetup failed. ret = %d \n", ret);
              return ret);
    ret = aclrtSynchronizeStreamWithTimeout(args.dispatchsetupstream, g_timeOut);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
              return ret);
    LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchSetup execute successfully.\n", args.rankId);
    return ret;
}

int ProcessDispatchTeardown(char *hcomEpName, char *hcomTpName, Args &args)
{
    int ret = 0;
    /****************************** Call dispatch_teardown ******************************/
    // Call the first-stage interface of the dispatch_teardown operator
    ret = aclnnMoeDistributeDispatchTeardownGetWorkspaceSize(
        x, yOut, expertIds, commCmdInfoOut, hcomEpName, g_epWorldSize, args.epRankId, g_moeExpertNum,
        g_expertSharedType, g_sharedExpertNum, g_sharedExpertRankNum, g_quantMode, globalBS, g_expertTokenNumsType,
        commType, nullptr, expandX, dynamicScales, assistInfoForCombineOut, expertTokenNums, &workspaceSize,
        &executor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchTeardownGetWorkspaceSize failed. ret = %d \n", ret);
              return ret);
    // Allocate device memory of the workspaceSize computed by the first-stage interface
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] DispatchTeardown aclrtMalloc failed. ret = %d\n", ret);
                  return ret);
    }
    // Call the second-stage interface of the dispatch_teardown operator
    ret = aclnnMoeDistributeDispatchTeardown(workspaceAddr, workspaceSize, executor, args.dispatchteardownstream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchTeardown failed. ret = %d \n", ret);
              return ret);
    // Wait for the task to finish
    ret = aclrtSynchronizeStreamWithTimeout(args.dispatchteardownstream, g_timeOut);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
              return ret);
    LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchTeardown execute successfully.\n", args.rankId);
    return ret;
}

void ReleaseTensor()
{
    if (x != nullptr) {
        aclDestroyTensor(x);
    }
    if (expertIds != nullptr) {
        aclDestroyTensor(expertIds);
    }
    if (scalesOptional != nullptr) {
        aclDestroyTensor(scalesOptional);
    }
    if (xActiveMaskOptional != nullptr) {
        aclDestroyTensor(xActiveMaskOptional);
    }
    if (yOut != nullptr) {
        aclDestroyTensor(yOut);
    }
    if (xOut != nullptr) {
        aclDestroyTensor(xOut);
    }
    if (sharedExpertXOptional != nullptr) {
        aclDestroyTensor(sharedExpertXOptional);
    }
    if (scales != nullptr) {
        aclDestroyTensor(scales);
    }
    if (expertScales != nullptr) {
        aclDestroyTensor(expertScales);
    }
    if (expandX != nullptr) {
        aclDestroyTensor(expandX);
    }
    if (dynamicScales != nullptr) {
        aclDestroyTensor(dynamicScales);
    }
    if (expandIdx != nullptr) {
        aclDestroyTensor(expandIdx);
    }
    if (expertTokenNums != nullptr) {
        aclDestroyTensor(expertTokenNums);
    }
    if (epRecvCounts != nullptr) {
        aclDestroyTensor(epRecvCounts);
    }
    if (tpRecvCounts != nullptr) {
        aclDestroyTensor(tpRecvCounts);
    }
    if (expandScales != nullptr) {
        aclDestroyTensor(expandScales);
    }
    if (assistInfoForCombineOut != nullptr) {
        aclDestroyTensor(assistInfoForCombineOut);
    }
    if (quantExpandXOut != nullptr) {
        aclDestroyTensor(quantExpandXOut);
    }
    if (commCmdInfoOut != nullptr) {
        aclDestroyTensor(commCmdInfoOut);
    }
    if (commCmdInfoOutforCombine != nullptr) {
        aclDestroyTensor(commCmdInfoOutforCombine);
    }
}

void ReleaseAddr()
{
    // Release device resources
    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }
    if (xDeviceAddr != nullptr) {
        aclrtFree(xDeviceAddr);
    }
    if (expertIdsDeviceAddr != nullptr) {
        aclrtFree(expertIdsDeviceAddr);
    }
    if (scalesDeviceAddr != nullptr) {
        aclrtFree(scalesDeviceAddr);
    }
    if (expertScalesDeviceAddr != nullptr) {
        aclrtFree(expertScalesDeviceAddr);
    }
    if (expandXDeviceAddr != nullptr) {
        aclrtFree(expandXDeviceAddr);
    }
    if (dynamicScalesDeviceAddr != nullptr) {
        aclrtFree(dynamicScalesDeviceAddr);
    }
    if (expandIdxDeviceAddr != nullptr) {
        aclrtFree(expandIdxDeviceAddr);
    }
    if (expertTokenNumsDeviceAddr != nullptr) {
        aclrtFree(expertTokenNumsDeviceAddr);
    }
    if (epRecvCountsDeviceAddr != nullptr) {
        aclrtFree(epRecvCountsDeviceAddr);
    }
    if (tpRecvCountsDeviceAddr != nullptr) {
        aclrtFree(tpRecvCountsDeviceAddr);
    }
    if (expandScalesDeviceAddr != nullptr) {
        aclrtFree(expandScalesDeviceAddr);
    }
    if (yOutDeviceAddr != nullptr) {
        aclrtFree(yOutDeviceAddr);
    }
    if (assistInfoForCombineOutDeviceAddr != nullptr) {
        aclrtFree(assistInfoForCombineOutDeviceAddr);
    }
    if (quantExpandXOutDeviceAddr != nullptr) {
        aclrtFree(quantExpandXOutDeviceAddr);
    }
    if (xOutDeviceAddr != nullptr) {
        aclrtFree(xOutDeviceAddr);
    }
    if (commCmdInfoOutDeviceAddr != nullptr) {
        aclrtFree(commCmdInfoOutDeviceAddr);
    }
    if (commCmdInfoOutShapeforcombineDeviceAddr != nullptr) {
        aclrtFree(commCmdInfoOutShapeforcombineDeviceAddr);
    }
}

void ReleaseResources()
{
    ReleaseTensor();
    ReleaseAddr();
}

void ProcessLocalExpertNumAndLocalToken(Args &args)
{
    if (args.epRankId < g_sharedExpertRankNum) {
        // Shared-expert card
        localExpertNum = 1;
        localToken = globalBS / g_sharedExpertRankNum;
    } else {
        // MoE expert card
        localExpertNum = g_moeExpertNum / (g_epWorldSize - g_sharedExpertRankNum);
        localToken = globalBS * (localExpertNum < g_k ?
localExpertNum : g_k); } expandXShape = {g_tpWorldSize * localToken, g_h}; dynamicScalesShape_H = {g_tpWorldSize * localToken, g_h}; dynamicScalesShape_ceil32 = {g_tpWorldSize * localToken, GetCeil(g_h, 32)}; dynamicScalesShape_ceil128 = {g_tpWorldSize * localToken, GetCeil(g_h, 128)}; dynamicScalesShape = {g_tpWorldSize * localToken}; dynamicModel2ScalesShape = {localToken * g_h}; expertTokenNumsShape = {localExpertNum}; epRecvCountsShape = {g_tpWorldSize * localExpertNum * g_epWorldSize}; tpRecvCountsShape = {g_tpWorldSize * localExpertNum}; expandScalesShape = {localToken}; expandXShapeSize = GetShapeSize(expandXShape); dynamicScalesShapeSizeH = GetShapeSize(dynamicScalesShape_H); dynamicScalesShapeSizeCeil32 = GetShapeSize(dynamicScalesShape_ceil32); dynamicScalesShapeSizeCeil128 = GetShapeSize(dynamicScalesShape_ceil128); dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape); dynamicModel2ScalesShapeSize = GetShapeSize(dynamicModel2ScalesShape); expandScalesShapeSize = GetShapeSize(expandScalesShape); expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape); epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape); tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape); expandXHostData.resize(expandXShapeSize, 0); expandXQuantHostData.resize(expandXShapeSize, 0); dynamicScalesHostDataH.resize(dynamicScalesShapeSizeH, 0); dynamicScalesHostDataCeil32.resize(dynamicScalesShapeSizeCeil32, 0); dynamicScalesHostDataCeil128.resize(dynamicScalesShapeSizeCeil128, 0); dynamicScalesHostData.resize(dynamicScalesShapeSize, 0); dynamicScalesMxfp8HostData.resize(dynamicScalesShapeSize, 0); dynamicModel2ScalesHostDataFloat.resize(dynamicModel2ScalesShapeSize, 0); dynamicModel2ScalesHostDataInt8.resize(dynamicModel2ScalesShapeSize, 0); expandScalesHostData.resize(expandScalesShapeSize, 0); expertTokenNumsHostData.resize(expertTokenNumsShapeSize, 0); epRecvCountsHostData.resize(epRecvCountsShapeSize, 0); tpRecvCountsHostData.resize(tpRecvCountsShapeSize, 0); } int 
LaunchOneProcess(Args &args) { ProcessLocalExpertNumAndLocalToken(args); int ret = aclrtSetCurrentContext(args.context); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed. ret: %d\n", ret); return ret); char hcomEpName[128] = {0}; // ep通信域名称 ret = HcclGetCommName(args.hcclEpComm, hcomEpName); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed. ret: %d\n", ret); return ret); char hcomTpName[128] = {0}; auto destroyFunc = [&args]() { std::cout << "== begin to destroy " << std::endl; HcclCommDestroy(args.hcclEpComm); HcclCommDestroy(args.hcclTpComm); aclrtDestroyStream(args.dispatchsetupstream); aclrtDestroyStream(args.dispatchteardownstream); aclrtDestroyContext(args.context); aclrtResetDevice(args.rankId); }; // 构造device侧变量 ret = CreateAclTensorX(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpertIds(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorScales(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpertScales(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpandX(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorDynamicScales(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpandIdx(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpertTokenNums(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. 
ret: %d\n", ret); return ret); ret = CreateAclTensorEpRecvCounts(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorTpRecvCounts(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorExpandScales(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorYOut(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorAssistInfoForCombineOut(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorcommCmdInfoOut(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorquantExpandXOut(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclTensorxOut(args); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] CreateAclTensor failed. ret: %d\n", ret); return ret); ret = CreateAclcommCmdInfoOutforCombine(args); LOG_PRINT("----Dispatch Setup aclnn start----"); ProcessDispatchSetup(hcomEpName, hcomTpName, args); LOG_PRINT("----Dispatch Teardown aclnn start----"); ProcessDispatchTeardown(hcomEpName, hcomTpName, args); ReleaseResources(); return 0; } int RunInProcess(int rank, int rankSize) { int ret = aclInit(nullptr); LOG_PRINT("aclInit: %d \n", ret); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed. ret = %d\n", ret); return ret); aclrtStream dispatchsetupstream; aclrtStream dispatchteardownstream; aclrtContext context; ret = aclrtSetDevice(rank); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed. 
ret = %d\n", ret); return ret); ret = aclrtCreateContext(&context, rank); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed. ret = %d\n", ret); return ret); ret = aclrtCreateStream(&dispatchsetupstream); ret = aclrtCreateStream(&dispatchteardownstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed. ret = %d\n", ret); return ret); HcclComm commsEp; HcclCommConfig config; HcclCommConfigInit(&config); config.hcclDeterministic = 1; config.hcclBufferSize = g_hcclBufferSize; strncpy(config.hcclCommName, "hccl_comm_test", COMM_NAME_MAX_LENGTH - 1); std::string rankTableFile = std::getenv("RANK_TABLE_FILE"); std::cout << "rankTableFilePath is :" << rankTableFile << std::endl; int rank_id = rank + FIRST_RANK_ID; ret = HcclCommInitClusterInfoConfig(rankTableFile.c_str(), rank_id, &config, &commsEp); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitClusterInfoConfig ep world %d failed. ret = %d\n", g_rankId, ret); return ret); Args args; uint32_t epRankId = rank_id / g_tpWorldSize; uint32_t tpRankId = rank_id % g_tpWorldSize; args.rankId = rank; args.epRankId = epRankId; args.tpRankId = tpRankId; args.hcclEpComm = commsEp; args.dispatchsetupstream = dispatchsetupstream; args.dispatchteardownstream = dispatchteardownstream; args.context = context; int res = LaunchOneProcess(args); if (res != ACL_SUCCESS) { std::cout << "run LaunchOneProcess failed, ret = " << res << std::endl; return res; } LOG_PRINT("[INFO] aclFinalize success\n"); return res; } int main(int argc, char *argv[]) { char *env_rankID = std::getenv("FIRST_RANK_ID"); if (!env_rankID) { std::cerr << "FIRST_RANK_ID环境变量未设置!\n"; return 1; } FIRST_RANK_ID = std::stoi(std::string(env_rankID)); std::cout << "FIRST_RANK_ID is: " << FIRST_RANK_ID << std::endl; const int processCount = 2; pid_t pids[processCount]; for (int i = 0; i < processCount; ++i) { pids[i] = fork(); if (pids[i] < 0) { std::cout << "fork failed ! 
" << pids[i] << std::endl; } else if (pids[i] == 0) { // 子进程,完成任务后退出 int ret = RunInProcess(i, processCount); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] RunInProcess failed. ret = %d\n", ret); return ret); exit(0); } } // 父进程等待所有子进程完成 for (int i = 0; i < processCount; ++i) { waitpid(pids[i], nullptr, 0); } return 0; } -