torch_npu.npu_moe_distribute_dispatch_v2
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
API功能: 需与torch_npu.npu_moe_distribute_combine_v2或torch_npu.npu_moe_distribute_combine_add_rms_norm配套使用,完成MoE的并行部署下的token dispatch_v2与combine_v2。
- 支持动态量化场景,对token数据先进行量化(可选),再进行EP(Expert Parallelism)域的alltoallv通信,再进行TP(Tensor Parallelism)域的allgatherv通信(可选);
- 支持特殊专家场景。
-
相较于npu_moe_distribute_dispatch接口,该接口变更如下:
- npu_moe_distribute_dispatch中shape为(Bs * K,)的返回值
expand_idx替换为shape为(A * 128,)的assist_info_for_combine,以包含更详细的token信息辅助torch_npu.npu_moe_distribute_combine_v2高效地进行全卡同步; - 新增输入参数
comm_alg,可用于代替HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE环境变量。
- npu_moe_distribute_dispatch中shape为(Bs * K,)的返回值
-
计算公式:
-
动态量化场景:
若
quant_mode不为2,即非动态量化场景:quant_out={ x,if quant_mode=0 CastToInt8( CastToFp32(x)× scales),if quant_mode≠0\ quant\_out= \begin{cases} \ x, & \quad \text{if}\ quant\_mode = 0 \\ \ CastToInt8(\ CastToFp32(x) \times \ scales ), & \quad \text{if } quant\_mode ≠ 0 \\ \end{cases}
alltoall_x_out= alltoallv( quant_out)\ alltoall\_x\_out= \ alltoallv(\ quant\_out)
expand_x={ allgatherv(alltoall_x_out), 有TP通信域 alltoall_x_out 无TP通信域\ expand\_x= \begin{cases} \ allgatherv(alltoall\_x\_out), & \quad \ 有TP通信域 \\ \ alltoall\_x\_out & \quad \ 无TP通信域 \\ \end{cases}
若
quant_mode为2,即动态量化场景:x_fp32= CastToFp32(x)× scales\ x\_fp32= \ CastToFp32(x) \times \ scales
dynamic_scales_value=127.0/Max(Abs(x_fp32))\ dynamic\_scales\_value = 127.0/Max(Abs(x\_fp32))
quant_out=CastToInt8( x_fp32× dynamic_scales_value)\ quant\_out=CastToInt8(\ x\_fp32 \times \ dynamic\_scales\_value )
alltoall_x_out= alltoallv( quant_out)\ alltoall\_x\_out= \ alltoallv(\ quant\_out)
alltoall_dynamic_scales_out=alltoall(1.0/dynamic_scales)\ alltoall\_dynamic\_scales\_out = alltoall(1.0/dynamic\_scales)
expand_x={ allgatherv(alltoall_x_out), 有TP通信域 alltoall_x_out 无TP通信域\ expand\_x= \begin{cases} \ allgatherv(alltoall\_x\_out), & \quad \ 有TP通信域 \\ \ alltoall\_x\_out & \quad \ 无TP通信域 \\ \end{cases}
dynamic_scales={ allgatherv(alltoall_dynamic_scales_out), 有TP通信域 alltoall_dynamic_scales_out 无TP通信域\ dynamic\_scales= \begin{cases} \ allgatherv(alltoall\_dynamic\_scales\_out), & \quad \ 有TP通信域 \\ \ \ alltoall\_dynamic\_scales\_out & \quad \ 无TP通信域 \\ \end{cases}
-
特殊专家场景:
零专家场景,即
zero_expert_num不为0:Moe(ori_x)=0Moe(ori\_x)=0
拷贝专家场景,即
copy_expert_num不为0:Moe(ori_x)=ori_xMoe(ori\_x)=ori\_x
常量专家场景,即
const_expert_num不为0:Moe(ori_x)=const_expert_alpha_1∗ori_x+const_expert_alpha_2∗const_expert_vMoe(ori\_x)=const\_expert\_alpha\_1*ori\_x+const\_expert\_alpha\_2*const\_expert\_v
参数ori_x、const_expert_alpha_1、const_expert_alpha_2、const_expert_v见torch_npu.npu_moe_distribute_combine_v2文档。
-
函数原型
torch_npu.npu_moe_distribute_dispatch_v2(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num, *, scales=None, x_active_mask=None, expert_scales=None, elastic_info=None, performance_info=None, group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, quant_mode=0, global_bs=0, expert_token_nums_type=1, comm_alg="", zero_expert_num=0, copy_expert_num=0, const_expert_num=0) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明
-
x (
Tensor):必选参数,表示计算使用的token数据,需根据expert_ids来发送给其他卡。要求为2维张量,shape为(BS, H),表示有BS个token,数据类型支持bfloat16、float16,数据格式为NDND,支持非连续的Tensor。 -
expert_ids (
Tensor):必选参数,表示每个token的topK个专家索引,决定每个token要发给哪些专家。要求为2维张量,shape为(BS, K),数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2的expert_ids输入,张量里value取值范围为[0, moe_expert_num),且同一行中的K个value不能重复。 -
group_ep (
str):必选参数,EP通信域名称,专家并行的通信域。字符串长度范围为[1,128)。Atlas A3 训练系列产品/Atlas A3 推理系列产品时,group_ep不能与group_tp相同。 -
ep_world_size(
int):必选参数,EP通信域size。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
ep_world_size的取值范围如下所示。comm_alg设置为"fullmesh"时,ep_world_size取值范围为16、32、64、128、256。comm_alg设置为"hierarchy"时,ep_world_size取值范围为16、32、64。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值支持[2, 768]。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
ep_rank_id (
int):必选参数,EP通信域本卡ID,取值范围[0, ep_world_size),同一个EP通信域中各卡的ep_rank_id不重复。 -
moe_expert_num (
int):必选参数,MoE专家数量,取值范围[1, 1024],并且满足以下条件:moe_expert_num%(ep_world_size - shared_expert_rank_num)=0。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:还需满足moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
-
*:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
-
scales (
Tensor):可选参数,表示每个专家的权重,非量化场景不传,动态量化场景可传可不传。若传值要求为2维张量,如果有共享专家,shape为(shared_expert_num+moe_expert_num, H),如果没有共享专家,shape为(moe_expert_num, H),数据类型支持float,数据格式为NDND,不支持非连续的Tensor。 -
x_active_mask (
Tensor):可选参数,表示token是否参与通信。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
comm_alg设置为"fullmesh"时,要求是一个1维或者2维张量。当输入为1维时,shape为(BS, ); 当输入为2维时,shape为(BS, K)。数据类型支持bool,数据格式要求为NDND,支持非连续的Tensor。当输入为1维时,参数为true表示对应的token参与通信,true必须排到false之前,例:{true, false, true} 为非法输入;当输入为2D时,参数为true表示当前token对应的expert_ids参与通信,若当前token对应的K个bool值全为false,表示当前token不会参与通信。默认所有token都会参与通信。当每张卡的BS数量不一致时,所有token必须全部有效。支持2维张量属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。comm_alg设置为"hierarchy"时,当前版本不支持,使用默认值None即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求是一个1维或者2维张量。当输入为1维时,shape为(BS, ); 当输入为2维时,shape为(BS, K)。数据类型支持
bool,数据格式要求为NDND,支持非连续的Tensor。当输入为1维时,参数为true表示对应的token参与通信,true必须排到false之前,例:{true, false, true} 为非法输入;当输入为2D时,参数为true表示当前token对应的expert_ids参与通信,若当前token对应的K个bool值全为false,表示当前token不会参与通信。默认所有token都会参与通信。当每张卡的BS数量不一致时,所有token必须全部有效。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
expert_scales (
Tensor):可选参数,表示每个token的topK个专家权重。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求为2维张量,shape为(BS, K),数据类型支持
float,数据格式为NDND,支持非连续的Tensor。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求为2维张量,shape为(BS, K),数据类型支持
-
elastic_info (
Tensor):预留参数,当前版本不支持,传默认值None即可。 -
performance_info (
Tensor):可选参数,表示本卡等待各卡数据的通信时间,单位为us(微秒)。单次算子调用各卡通信耗时会累加到该Tensor上,算子内部不进行自动清零,因此用户每次启用此Tensor开始记录耗时前需对Tensor清零。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:可选择传入有效数据或填None,传入None时表示不使能记录通信耗时功能;当传入有效数据时,要求是一个1D的Tensor,shape为(ep_world_size,),数据类型支持int64;数据格式要求为ND,支持非连续的Tensor。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:预留参数,当前版本不支持,传默认值None即可。
-
group_tp (
string):可选参数,TP通信域名称,数据并行的通信域。若有TP域通信需要传参,若无TP域通信,使用默认值""即可。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:eager模式使用默认值即可,图模式传入与
group_ep相同。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:字符串长度范围为[1, 128),不能和
group_ep相同。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:eager模式使用默认值即可,图模式传入与
-
tp_world_size (
int):可选参数,TP通信域size。有TP域通信才需要传参。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 2],0和1表示无TP域通信,2表示有TP域通信。
-
tp_rank_id (
int):可选参数,TP通信域本卡ID。有TP域通信才需要传参。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 1],默认为0,同一个TP通信域中各卡的
tp_rank_id不重复。无TP域通信时,传0即可。
-
expert_shard_type (
int):可选参数,表示共享专家卡排布类型。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前仅支持0,表示共享专家卡排在MoE专家卡前面。
-
shared_expert_num (
int):可选参数,表示共享专家数量,一个共享专家可以复制部署到多个卡上。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, 4],0表示无共享专家,默认值为1。
-
shared_expert_rank_num (
int):可选参数,表示共享专家卡数量。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持共享专家,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, ep_world_size)。取0表示无共享专家,不取0需满足shared_expert_rank_num%shared_expert_num=0。
-
quant_mode (
int):可选参数,表示量化模式。支持取值:0表示非量化(默认),2表示动态量化。当quant_mode为2,dynamic_scales不为None;当quant_mode为0,dynamic_scales为None。 -
global_bs (
int):可选参数,表示EP域全局的batch size大小。当每个rank的BS不同时,支持传入max_bs*ep_world_size,其中max_bs表示单rank BS最大值;当每个rank的BS相同时,支持取值0或BS*ep_world_size。 -
expert_token_nums_type (
int):可选参数,表示输出expert_token_nums的值类型,取值范围[0, 1],0表示每个专家收到token数量的前缀和,1表示每个专家收到的token数量(默认)。 -
comm_alg (
string):可选参数,表示通信亲和内存布局算法。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:当前版本支持"","fullmesh","hierarchy"三种输入方式。推荐配置"hierarchy"并搭配25.0.RC1.1及以上版本驱动使用。
- "": 配置HCCL_INTRA_PCIE_ENABLE=1和HCCL_INTRA_ROCE_ENABLE=0时,调用"hierarchy"算法,否则调用"fullmesh"算法。不推荐使用该方式。
- "fullmesh": token数据直接通过RDMA方式发往topk个目标专家所在的卡。
- "hierarchy": token数据经过跨机、机内两次发送,仅不同server同号卡之间使用RDMA通信,server内使用HCCS通信。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前版本仅支持""、"fullmesh_v1"、"fullmesh_v2"三种输入方式。
- "":默认值,使能fullmesh_v1模板;
- "fullmesh_v1":使能fullmesh_v1模板;
- "fullmesh_v2":使能fullmesh_v2模板,其中fullmesh_v2模板仅在
tp_world_size取值为1时生效,且不支持各卡Bs不一致、使能x_active_mask和使能特殊专家等场景。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:当前版本支持"","fullmesh","hierarchy"三种输入方式。推荐配置"hierarchy"并搭配25.0.RC1.1及以上版本驱动使用。
-
zero_expert_num (
int):可选参数,表示零专家的数量。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
comm_alg设置为"fullmesh"时,取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID值是[moe_expert_num, moe_expert_num+zero_expert_num)。参数为非0时属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。comm_alg设置为"hierarchy"时,当前版本不支持,传0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID值是[moe_expert_num, moe_expert_num+zero_expert_num),当
comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
copy_expert_num (
int):可选参数,表示拷贝专家的数量。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
comm_alg设置为"fullmesh"时,取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID值是[moe_expert_num+zero_expert_num, moe_expert_num+zero_expert_num+copy_expert_num)。参数为非0时属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。comm_alg设置为"hierarchy"时,当前版本不支持,传0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID值是[moe_expert_num+zero_expert_num, moe_expert_num+zero_expert_num+copy_expert_num),当
comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
const_expert_num (
int):可选参数,表示常量专家的数量。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:当前版本不支持,传0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的常量专家的ID值是[moe_expert_num+zero_expert_num+copy_expert_num, moe_expert_num+zero_expert_num+copy_expert_num+const_expert_num),当
comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。
输出说明
-
expand_x (
Tensor):表示本卡收到的token数据,要求为2维张量,shape为(max(tp_world_size, 1) *A, H),A表示在EP通信域可能收到的最大token数,数据类型支持bfloat16、float16、int8。量化时类型为int8,非量化时与x数据类型保持一致。数据格式为NDND,支持非连续的Tensor。 -
dynamic_scales (
Tensor):表示计算得到的动态量化参数。当quant_mode不为0时才有该输出,要求为1维张量,shape为(A,),数据类型支持float,数据格式支持NDND,支持非连续的Tensor。 -
assist_info_for_combine (
Tensor):表示给同一专家发送的token个数,要求是一个1维张量,shape为(A * 128, )。数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2的assist_info_for_combine输入。 -
expert_token_nums (
Tensor):本卡每个专家实际收到的token数量,要求为1维张量,shape为(local_expert_num,),数据类型int64,数据格式支持NDND,支持非连续的Tensor。 -
ep_recv_counts (
Tensor):表示EP通信域各卡收到的token数(token数以前缀和的形式表示),要求为1维张量,数据类型int32,数据格式支持NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2的ep_send_counts输入。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(moe_expert_num+2*global_bs*K*server_num, ),前
moe_expert_num个数表示在EP通信域内,该卡上每个专家收到来自其他各卡的token数(以前缀和的形式表示),2*global_bs*K*server_num用于存储机间和机内通信前,combine可提前做reduce操作的token个数和通信区偏移量,global_bs传入0时此处按照bs*ep_world_size计算。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求shape为(ep_world_size*max(tp_world_size, 1)*local_expert_num, )。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(moe_expert_num+2*global_bs*K*server_num, ),前
-
tp_recv_counts (
Tensor):表示TP通信域各卡收到的token数量。对应torch_npu.npu_moe_distribute_combine_v2的tp_send_counts输入。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP通信域,暂无该输出,
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持TP通信域,要求是一个1D Tensor,shape为(tp_world_size, ),数据类型支持
int32,数据格式为NDND,支持非连续的Tensor。
-
expand_scales (
Tensor):表示expert_scales与x一起进行alltoallv之后的输出。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求是一个1维张量,shape为(A, ),数据类型支持
float,数据格式要求为NDND,支持非连续的Tensor。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:暂不支持该输出,返回None。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求是一个1维张量,shape为(A, ),数据类型支持
约束说明
-
该接口支持推理场景下使用。
-
该接口支持静态图模式,
npu_moe_distribute_dispatch_v2和npu_moe_distribute_combine_v2必须配套使用。 -
在不同产品型号、不同通信算法或不同版本中,
npu_moe_distribute_dispatch_v2的Tensor输出assist_info_for_combine、ep_recv_counts、tp_recv_counts、expand_scales中的元素值可能不同,使用时直接将上述Tensor传给npu_moe_distribute_combine_v2对应参数即可,模型其他业务逻辑不应对其存在依赖。 -
调用接口过程中使用的
group_ep、ep_world_size、moe_expert_num、group_tp、tp_world_size、expert_shard_type、shared_expert_num、shared_expert_rank_num、global_bs参数取值所有卡需保持一致,group_ep、ep_world_size、group_tp、tp_world_size、expert_shard_type、global_bs网络不同层中也需保持一致,且和torch_npu.npu_moe_distribute_combine_v2对应参数也保持一致。 -
Atlas A3 训练系列产品/Atlas A3 推理系列产品:该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。
-
moe_expert_num + zero_expert_num + copy_expert_num + const_expert_num < MAX_INT32。
-
参数里Shape使用的变量如下:
-
A:表示本卡接收的最大token数量,取值范围如下
- 对于共享专家,要满足A=BS*shared_expert_num/shared_expert_rank_num。
- 对于MoE专家,当
global_bs为0时,要满足A>=BS*ep_world_size*min(local_expert_num, K);当global_bs不为0时,要满足A>=global_bs* min(local_expert_num, K)。
-
H:表示hidden size隐藏层大小。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
H的取值范围如下所示。comm_alg设置为"fullmesh"时,H的取值范围(0, 7168],且保证是32的整数倍。comm_alg设置为"hierarchy"且驱动版本不低于25.0.RC1.1时,H的取值范围(0, 10 * 1024],且保证是32的整数倍。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值为[1024, 8192]。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
BS:表示batch sequence size,即本卡最终输出的token数量。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:取值范围为0<BS≤256。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0<BS≤512,当
comm_alg为"fullmesh_v2"时,需满足0<BS≤256。
-
K:表示选取topK个专家,取值范围为0<K≤16,同时满足0 < K ≤ moe_expert_num + zero_expert_num + copy_expert_num + const_expert_num,当
comm_alg为"fullmesh_v2"时,取值范围为0<K≤12。 -
server_num:表示服务器的节点数,取值只支持2、4、8。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:仅该场景的shape使用了该变量。
-
local_expert_num:表示本卡专家数量。
- 对于共享专家卡,local_expert_num为1。
- 对于MoE专家卡,local_expert_num=moe_expert_num/(ep_world_size-shared_expert_rank_num),当local_expert_num大于1时,不支持TP域通信。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:应满足0 < local_expert_num * ep_world_size ≤ 2048。
-
-
HCCL通信域缓存区大小:
调用本接口前需检查通信域缓存区大小取值是否合理,单位MB,不配置时默认为200MB。
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品: 该场景支持通过环境变量HCCL_BUFFSIZE配置。
- comm_alg配置为"": 依照HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE配置选择"fullmesh"或"hierarchy"公式。
- comm_alg配置为"fullmesh": 设置大小要求>=2*(BS*ep_world_size*min(local_expert_num, K)*H*sizeof(uint16)+2MB)。
- comm_alg配置为"hierarchy": 设置大小要求=moe_expert_num*BS*(H*sizeof(dtype_x)+4*((K+7)/8*8)*sizeof(uint32))+4MB+100MB,不要求moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品: 该场景不仅支持通过环境变量HCCL_BUFFSIZE配置,还支持通过hccl_buffer_size配置(参考《PyTorch训练模型迁移调优》中“性能调优>性能调优方法>通信优化>优化方法>hccl_buffer_size”章节)。
- ep通信域内,comm_alg配置为"fullmesh_v1"或"": 设置大小要求 >= 2 * (local_expert_num * max_bs * ep_world_size * Align512(Align32(2 * H) + 64) + (K + shared_expert_num) * max_bs * Align512(2 * H))。
- ep通信域内,comm_alg配置为"fullmesh_v2": 设置大小要求 >= 2 * (local_expert_num * max_bs * ep_world_size * 480Align512(Align32(2 * H) + 64) + (K + shared_expert_num) * max_bs * Align512(2 * H))。
- tp通信域内:设置大小要求 >= (A * Align512(Align32(h * 2) + 44) + A * Align512(h * 2)) * 2。
- 其中 480Align512(x) = ((x+480-1)/480)*512,Align512(x) = ((x+512-1)/512)*512,Align32(x) = ((x+32-1)/32)*32。
-
-
HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:该环境变量不再推荐使用,建议comm_alg配置"hierarchy"。
-
本文公式中的“/”表示整除。
-
通信域使用约束:
-
一个模型中的
npu_moe_distribute_dispatch_v2和npu_moe_distribute_combine_v2算子仅支持相同EP通信域,且该通信域中不允许有其他算子。 -
一个模型中的
npu_moe_distribute_dispatch_v2和npu_moe_distribute_combine_v2算子仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。 -
Atlas A3 训练系列产品/Atlas A3 推理系列产品:一个通信域内的节点需在一个超节点内,不支持跨超节点。
-
-
组网约束:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:多机场景仅支持交换机组网,不支持双机直连组网。
-
版本配套约束:
静态图模式下,从Ascend Extension for PyTorch 8.0.0版本开始,Ascend Extension for PyTorch框架会对静态图中最后一个节点输出结果做Meta推导与inferShape推导的结果强校验。当图中只有一个Dispatch_v2算子,若CANN版本落后于Ascend Extension for PyTorch版本,会出现Shape不匹配报错,建议用户升级CANN版本,详细的版本配套关系参见《Ascend Extension for PyTorch 版本说明》中“相关产品版本配套说明”。
调用示例
-
单算子模式调用
import os import torch import random import torch_npu import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die shared_expert_rank_num = 0 # 共享专家数 moe_expert_num = 32 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - shared_expert_rank_num local_moe_expert_num = moe_expert_num // moe_rank_num globalBS = bs * ep_world_size is_shared = (shared_expert_rank_num > 0) is_quant = (quant_mode > 0) zero_expert_num = 1 copy_expert_num = 1 const_expert_num = 1 def gen_const_expert_alpha_1(): const_expert_alpha_1 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_alpha_1 def gen_const_expert_alpha_2(): const_expert_alpha_2 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_alpha_2 def gen_const_expert_v(): const_expert_v = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_v def get_new_group(rank): for i in range(tp_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]] ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def warm_up_dispatch(rank, group_ep, group_tp): x_warm_up = torch.empty(size=[1, h], dtype=input_dtype).uniform_(-1024, 1024).to(input_dtype).npu() expert_ids_warm_up = torch.arange(0, k, dtype=torch.int32).unsqueeze(0).npu() dispatch_kwargs_before = get_dispatch_kwargs_warmup( x_warm_up=x_warm_up, expert_ids_warm_up=expert_ids_warm_up, group_ep=group_ep, group_tp=group_tp, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, ) ( expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, _ ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs_before) return expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts def get_dispatch_kwargs_warmup( x_warm_up, expert_ids_warm_up, group_ep, group_tp, ep_rank_id, tp_rank_id, ): x_warm_up = x_warm_up.to(input_dtype).npu() expert_ids_warm_up = expert_ids_warm_up.to(torch.int32).npu() return { 'x': x_warm_up, 'expert_ids': expert_ids_warm_up, 'x_active_mask': None, 'group_ep': group_ep, 'group_tp': group_tp, 'ep_rank_id': ep_rank_id, 'tp_rank_id': tp_rank_id, 'ep_world_size': ep_world_size, 'tp_world_size': tp_world_size, 'expert_shard_type': 0, 'shared_expert_num': 0, 'shared_expert_rank_num': shared_expert_rank_num, 'moe_expert_num': moe_expert_num, 'scales': None, 'quant_mode': 2, 'global_bs': 16, } def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}') ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = torch.tensor([[5, 7, 17, 4, 2, 6, 11, 16], [10, 12, 13, 15, 19, 4, 18, 1], [19, 33, 1, 17, 9, 5, 0, 32], [19, 11, 17, 0, 10, 5, 7, 9], [10, 16, 11, 17, 33, 8, 9, 3], [12, 19, 5, 7, 1, 3, 18, 16], [11, 9, 13, 16, 12, 33, 17, 14], [16, 4, 9, 5, 0, 10, 11, 17]], dtype=torch.int32).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moe_expert_num, h) if shared_expert_rank_num else (moe_expert_num, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None const_expert_alpha_1 = gen_const_expert_alpha_1().npu() const_expert_alpha_2 = gen_const_expert_alpha_2().npu() const_expert_v = gen_const_expert_v().npu() out = warm_up_dispatch(rank, ep_hcomm_info, tp_hcomm_info) expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales = torch_npu.npu_moe_distribute_dispatch_v2( x=x, expert_ids=expert_ids, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, scales=scales, quant_mode=quant_mode, global_bs=globalBS, zero_expert_num=zero_expert_num, copy_expert_num=copy_expert_num, const_expert_num=const_expert_num) if is_quant: expand_x = expand_x.to(input_dtype) x = torch_npu.npu_moe_distribute_combine_v2(expand_x=expand_x, expert_ids=expert_ids, assist_info_for_combine=assist_info_for_combine, ep_send_counts=ep_recv_counts, tp_send_counts=tp_recv_counts, expert_scales=expert_scales, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, global_bs=globalBS, ori_x=x, const_expert_alpha_1=const_expert_alpha_1, const_expert_alpha_2=const_expert_alpha_2, const_expert_v=const_expert_v, zero_expert_num=zero_expert_num, copy_expert_num=copy_expert_num, const_expert_num=const_expert_num) print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={shared_expert_rank_num}") print(f"moe_expert_num={moe_expert_num}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if shared_expert_rank_num > ep_world_size: print("shared_expert_rank_num 不能大于 ep_world_size") exit(0) if shared_expert_rank_num > 0 and ep_world_size % shared_expert_rank_num != 0: print("ep_world_size 必须是 shared_expert_rank_num的整数倍") exit(0) if moe_expert_num % moe_rank_num != 0: print("moe_expert_num 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.") -
图模式调用
# 仅支持静态图 import os import torch import random import torch_npu import torchair import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp import time # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die shared_expert_rank_num = 0 # 共享专家数 moe_expert_num = 32 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - shared_expert_rank_num local_moe_expert_num = moe_expert_num // moe_rank_num globalBS = bs * ep_world_size is_shared = (shared_expert_rank_num > 0) is_quant = (quant_mode > 0) zero_expert_num = 1 copy_expert_num = 1 const_expert_num = 1 class MOE_DISTRIBUTE_GRAPH_Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, expert_ids, group_ep, group_tp, ep_world_size, tp_world_size, ep_rank_id, tp_rank_id, expert_shard_type, shared_expert_rank_num, moe_expert_num, scales, quant_mode, global_bs, expert_scales, elastic_info, const_expert_alpha_1, const_expert_alpha_2, const_expert_v, zero_expert_num, copy_expert_num, const_expert_num): output_dispatch_npu = torch_npu.npu_moe_distribute_dispatch_v2( x=x, expert_ids=expert_ids, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, scales=scales, quant_mode=quant_mode, global_bs=global_bs, elastic_info=elastic_info, zero_expert_num=zero_expert_num, copy_expert_num=copy_expert_num, const_expert_num=const_expert_num ) expand_x_npu, _, assist_info_for_combine_npu, _, ep_recv_counts_npu, tp_recv_counts_npu, expand_scales = output_dispatch_npu if expand_x_npu.dtype == torch.int8: expand_x_npu = expand_x_npu.to(input_dtype) output_combine_npu = torch_npu.npu_moe_distribute_combine_v2( expand_x=expand_x_npu, expert_ids=expert_ids, assist_info_for_combine=assist_info_for_combine_npu, ep_send_counts=ep_recv_counts_npu, tp_send_counts=tp_recv_counts_npu, expert_scales=expert_scales, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, global_bs=global_bs, elastic_info=elastic_info, ori_x=x, const_expert_alpha_1=const_expert_alpha_1, const_expert_alpha_2=const_expert_alpha_2, const_expert_v=const_expert_v, zero_expert_num=zero_expert_num, copy_expert_num=copy_expert_num, const_expert_num=const_expert_num ) x = output_combine_npu x_combine_res = output_combine_npu return [x_combine_res, output_combine_npu] def gen_const_expert_alpha_1(): const_expert_alpha_1 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_alpha_1 def gen_const_expert_alpha_2(): const_expert_alpha_2 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_alpha_2 def gen_const_expert_v(): const_expert_v = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1) return const_expert_v def get_new_group(rank): for i in range(tp_world_size): ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def warm_up_dispatch(rank, group_ep, group_tp): x_warm_up = torch.empty(size=[1, h], dtype=input_dtype).uniform_(-1024, 1024).to(input_dtype).npu() expert_ids_warm_up = torch.arange(0, k, dtype=torch.int32).unsqueeze(0).npu() dispatch_kwargs_before = get_dispatch_kwargs_warmup( x_warm_up=x_warm_up, expert_ids_warm_up=expert_ids_warm_up, group_ep=group_ep, group_tp=group_tp, ep_rank_id=rank//tp_world_size, tp_rank_id=rank%tp_world_size, ) ( expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, _ ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs_before) return expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts def get_dispatch_kwargs_warmup( x_warm_up, expert_ids_warm_up, group_ep, group_tp, ep_rank_id, tp_rank_id, ): x_warm_up = x_warm_up.to(input_dtype).npu() expert_ids_warm_up = expert_ids_warm_up.to(torch.int32).npu() return { 'x': x_warm_up, 'expert_ids': expert_ids_warm_up, 'x_active_mask': None, 'group_ep': group_ep, 'group_tp': group_tp, 'ep_rank_id': ep_rank_id, 'tp_rank_id': tp_rank_id, 'ep_world_size': ep_world_size, 'tp_world_size': tp_world_size, 'expert_shard_type': 0, 'shared_expert_num': 0, 'shared_expert_rank_num': shared_expert_rank_num, 'moe_expert_num': moe_expert_num, 'scales': None, 'quant_mode': 2, 'global_bs': 16, } def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group( backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}' ) ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = torch.tensor([ [0, 8, 4, 1, 6, 12, 14, 17], [14, 10, 7, 3, 0, 12, 11, 17], [12, 0, 5, 11, 19, 4, 6, 18], [17, 3, 4, 10, 18, 0, 1, 2], [13, 16, 9, 10, 15, 6, 7, 14], [17, 15, 14, 8, 16, 18, 3, 12], [4, 12, 2, 17, 15, 3, 9, 10], [16, 7, 12, 9, 18, 3, 19, 17] ], dtype=torch.int32).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moe_expert_num, h) if shared_expert_rank_num else (moe_expert_num, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None elastic_info = None available_ranks = [1, 2, 3, 5, 7, 9, 10, 11, 13, 14] const_expert_alpha_1 = gen_const_expert_alpha_1().npu() const_expert_alpha_2 = gen_const_expert_alpha_2().npu() const_expert_v = gen_const_expert_v().npu() out = warm_up_dispatch(rank, ep_hcomm_info, tp_hcomm_info) model = MOE_DISTRIBUTE_GRAPH_Model() model = model.npu() npu_backend = torchair.get_npu_backend() model = torch.compile(model, backend=npu_backend, dynamic=False) output = model.forward( x, expert_ids, ep_hcomm_info, tp_hcomm_info, ep_world_size, tp_world_size, rank // tp_world_size, rank % tp_world_size, 0, shared_expert_rank_num, moe_expert_num, scales, quant_mode, globalBS, expert_scales, elastic_info, const_expert_alpha_1, const_expert_alpha_2, const_expert_v, zero_expert_num, copy_expert_num, const_expert_num ) torch.npu.synchronize() print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') time.sleep(10) if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={shared_expert_rank_num}") print(f"moe_expert_num={moe_expert_num}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if shared_expert_rank_num > ep_world_size: print("shared_expert_rank_num 不能大于 ep_world_size") exit(0) if shared_expert_rank_num > 0 and ep_world_size % shared_expert_rank_num != 0: print("ep_world_size 必须是 shared_expert_rank_num的整数倍") exit(0) if moe_expert_num % moe_rank_num != 0: print("moe_expert_num 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.")