torch_npu.npu_moe_distribute_dispatch_v2

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

  • API功能: 需与torch_npu.npu_moe_distribute_combine_v2torch_npu.npu_moe_distribute_combine_add_rms_norm配套使用,完成MoE的并行部署下的token dispatch_v2与combine_v2。

    • 支持动态量化场景,对token数据先进行量化(可选),再进行EP(Expert Parallelism)域的alltoallv通信,再进行TP(Tensor Parallelism)域的allgatherv通信(可选);
    • 支持特殊专家场景。
  • 相较于npu_moe_distribute_dispatch接口,该接口变更如下:

    • npu_moe_distribute_dispatch中shape为(Bs * K,)的返回值expand_idx替换为shape为(A * 128,)的assist_info_for_combine,以包含更详细的token信息辅助torch_npu.npu_moe_distribute_combine_v2高效地进行全卡同步;
    • 新增输入参数comm_alg,可用于代替HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE环境变量。
  • 计算公式:

    • 动态量化场景:

      quant_mode不为2,即非动态量化场景:

       quant_out={ x,if quant_mode=0 CastToInt8( CastToFp32(x)× scales),if quant_mode≠0\ quant\_out= \begin{cases} \ x, & \quad \text{if}\ quant\_mode = 0 \\ \ CastToInt8(\ CastToFp32(x) \times \ scales ), & \quad \text{if } quant\_mode ≠ 0 \\ \end{cases}

       alltoall_x_out= alltoallv( quant_out)\ alltoall\_x\_out= \ alltoallv(\ quant\_out)

       expand_x={ allgatherv(alltoall_x_out), 有TP通信域 alltoall_x_out 无TP通信域\ expand\_x= \begin{cases} \ allgatherv(alltoall\_x\_out), & \quad \ 有TP通信域 \\ \ alltoall\_x\_out & \quad \ 无TP通信域 \\ \end{cases}

      quant_mode2,即动态量化场景:

       x_fp32= CastToFp32(x)× scales\ x\_fp32= \ CastToFp32(x) \times \ scales

       dynamic_scales_value=127.0/Max(Abs(x_fp32))\ dynamic\_scales\_value = 127.0/Max(Abs(x\_fp32))

       quant_out=CastToInt8( x_fp32× dynamic_scales_value)\ quant\_out=CastToInt8(\ x\_fp32 \times \ dynamic\_scales\_value )

       alltoall_x_out= alltoallv( quant_out)\ alltoall\_x\_out= \ alltoallv(\ quant\_out)

       alltoall_dynamic_scales_out=alltoall(1.0/dynamic_scales)\ alltoall\_dynamic\_scales\_out = alltoall(1.0/dynamic\_scales)

       expand_x={ allgatherv(alltoall_x_out), 有TP通信域 alltoall_x_out 无TP通信域\ expand\_x= \begin{cases} \ allgatherv(alltoall\_x\_out), & \quad \ 有TP通信域 \\ \ alltoall\_x\_out & \quad \ 无TP通信域 \\ \end{cases}

       dynamic_scales={ allgatherv(alltoall_dynamic_scales_out), 有TP通信域  alltoall_dynamic_scales_out 无TP通信域\ dynamic\_scales= \begin{cases} \ allgatherv(alltoall\_dynamic\_scales\_out), & \quad \ 有TP通信域 \\ \ \ alltoall\_dynamic\_scales\_out & \quad \ 无TP通信域 \\ \end{cases}

    • 特殊专家场景:

      零专家场景,即zero_expert_num不为0:

      Moe(ori_x)=0Moe(ori\_x)=0

      拷贝专家场景,即copy_expert_num不为0:

      Moe(ori_x)=ori_xMoe(ori\_x)=ori\_x

      常量专家场景,即const_expert_num不为0:

      Moe(ori_x)=const_expert_alpha_1∗ori_x+const_expert_alpha_2∗const_expert_vMoe(ori\_x)=const\_expert\_alpha\_1*ori\_x+const\_expert\_alpha\_2*const\_expert\_v

      参数ori_x、const_expert_alpha_1、const_expert_alpha_2、const_expert_v见torch_npu.npu_moe_distribute_combine_v2文档。

函数原型

torch_npu.npu_moe_distribute_dispatch_v2(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num, *, scales=None, x_active_mask=None, expert_scales=None, elastic_info=None, performance_info=None, group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, quant_mode=0, global_bs=0, expert_token_nums_type=1, comm_alg="", zero_expert_num=0, copy_expert_num=0, const_expert_num=0) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)

参数说明

  • x (Tensor):必选参数,表示计算使用的token数据,需根据expert_ids来发送给其他卡。要求为2维张量,shape为(BS, H),表示有BS个token,数据类型支持bfloat16float16,数据格式为NDND,支持非连续的Tensor。

  • expert_ids (Tensor):必选参数,表示每个token的topK个专家索引,决定每个token要发给哪些专家。要求为2维张量,shape为(BS, K),数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2expert_ids输入,张量里value取值范围为[0, moe_expert_num),且同一行中的K个value不能重复。

  • group_ep (str):必选参数,EP通信域名称,专家并行的通信域。字符串长度范围为[1,128)。Atlas A3 训练系列产品/Atlas A3 推理系列产品时,group_ep不能与group_tp相同。

  • ep_world_size(int):必选参数,EP通信域size。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:ep_world_size的取值范围如下所示。
      • comm_alg设置为"fullmesh"时,ep_world_size取值范围为16、32、64、128、256。
      • comm_alg设置为"hierarchy"时,ep_world_size取值范围为16、32、64。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值支持[2, 768]。
  • ep_rank_id (int):必选参数,EP通信域本卡ID,取值范围[0, ep_world_size),同一个EP通信域中各卡的ep_rank_id不重复。

  • moe_expert_num (int):必选参数,MoE专家数量,取值范围[1, 1024],并且满足以下条件:moe_expert_num%(ep_world_size - shared_expert_rank_num)=0。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:还需满足moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
  • *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。

  • scales (Tensor):可选参数,表示每个专家的权重,非量化场景不传,动态量化场景可传可不传。若传值要求为2维张量,如果有共享专家,shape为(shared_expert_num+moe_expert_num, H),如果没有共享专家,shape为(moe_expert_num, H),数据类型支持float,数据格式为NDND,不支持非连续的Tensor。

  • x_active_mask (Tensor):可选参数,表示token是否参与通信。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:
      • comm_alg设置为"fullmesh"时,要求是一个1维或者2维张量。当输入为1维时,shape为(BS, ); 当输入为2维时,shape为(BS, K)。数据类型支持bool,数据格式要求为NDND,支持非连续的Tensor。当输入为1维时,参数为true表示对应的token参与通信,true必须排到false之前,例:{true, false, true} 为非法输入;当输入为2D时,参数为true表示当前token对应的expert_ids参与通信,若当前token对应的K个bool值全为false,表示当前token不会参与通信。默认所有token都会参与通信。当每张卡的BS数量不一致时,所有token必须全部有效。支持2维张量属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。
      • comm_alg设置为"hierarchy"时,当前版本不支持,使用默认值None即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求是一个1维或者2维张量。当输入为1维时,shape为(BS, ); 当输入为2维时,shape为(BS, K)。数据类型支持bool,数据格式要求为NDND,支持非连续的Tensor。当输入为1维时,参数为true表示对应的token参与通信,true必须排到false之前,例:{true, false, true} 为非法输入;当输入为2D时,参数为true表示当前token对应的expert_ids参与通信,若当前token对应的K个bool值全为false,表示当前token不会参与通信。默认所有token都会参与通信。当每张卡的BS数量不一致时,所有token必须全部有效。
  • expert_scales (Tensor):可选参数,表示每个token的topK个专家权重。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求为2维张量,shape为(BS, K),数据类型支持float,数据格式为NDND,支持非连续的Tensor。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:暂不支持该参数,使用默认值即可。
  • elastic_info (Tensor):预留参数,当前版本不支持,传默认值None即可。

  • performance_info (Tensor):可选参数,表示本卡等待各卡数据的通信时间,单位为us(微秒)。单次算子调用各卡通信耗时会累加到该Tensor上,算子内部不进行自动清零,因此用户每次启用此Tensor开始记录耗时前需对Tensor清零。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:可选择传入有效数据或填None,传入None时表示不使能记录通信耗时功能;当传入有效数据时,要求是一个1D的Tensor,shape为(ep_world_size,),数据类型支持int64;数据格式要求为ND,支持非连续的Tensor。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:预留参数,当前版本不支持,传默认值None即可。
  • group_tp (string):可选参数,TP通信域名称,数据并行的通信域。若有TP域通信需要传参,若无TP域通信,使用默认值""即可。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:eager模式使用默认值即可,图模式传入与group_ep相同。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:字符串长度范围为[1, 128),不能和group_ep相同。
  • tp_world_size (int):可选参数,TP通信域size。有TP域通信才需要传参。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值0即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 2],0和1表示无TP域通信,2表示有TP域通信。
  • tp_rank_id (int):可选参数,TP通信域本卡ID。有TP域通信才需要传参。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 1],默认为0,同一个TP通信域中各卡的tp_rank_id不重复。无TP域通信时,传0即可。
  • expert_shard_type (int):可选参数,表示共享专家卡排布类型。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前仅支持0,表示共享专家卡排在MoE专家卡前面。
  • shared_expert_num (int):可选参数,表示共享专家数量,一个共享专家可以复制部署到多个卡上。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, 4],0表示无共享专家,默认值为1。
  • shared_expert_rank_num (int):可选参数,表示共享专家卡数量。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持共享专家,使用默认值即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, ep_world_size)。取0表示无共享专家,不取0需满足shared_expert_rank_num%shared_expert_num=0。
  • quant_mode (int):可选参数,表示量化模式。支持取值:0表示非量化(默认),2表示动态量化。当quant_mode为2,dynamic_scales不为None;当quant_mode为0,dynamic_scales为None。

  • global_bs (int):可选参数,表示EP域全局的batch size大小。当每个rank的BS不同时,支持传入max_bs*ep_world_size,其中max_bs表示单rank BS最大值;当每个rank的BS相同时,支持取值0或BS*ep_world_size。

  • expert_token_nums_type (int):可选参数,表示输出expert_token_nums的值类型,取值范围[0, 1],0表示每个专家收到token数量的前缀和,1表示每个专家收到的token数量(默认)。

  • comm_alg (string):可选参数,表示通信亲和内存布局算法。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:当前版本支持"","fullmesh","hierarchy"三种输入方式。推荐配置"hierarchy"并搭配25.0.RC1.1及以上版本驱动使用。
      • "": 配置HCCL_INTRA_PCIE_ENABLE=1和HCCL_INTRA_ROCE_ENABLE=0时,调用"hierarchy"算法,否则调用"fullmesh"算法。不推荐使用该方式。
      • "fullmesh": token数据直接通过RDMA方式发往topk个目标专家所在的卡。
      • "hierarchy": token数据经过跨机、机内两次发送,仅不同server同号卡之间使用RDMA通信,server内使用HCCS通信。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前版本仅支持""、"fullmesh_v1"、"fullmesh_v2"三种输入方式。
      • "":默认值,使能fullmesh_v1模板;
      • "fullmesh_v1":使能fullmesh_v1模板;
      • "fullmesh_v2":使能fullmesh_v2模板,其中fullmesh_v2模板仅在tp_world_size取值为1时生效,且不支持各卡Bs不一致、使能x_active_mask和使能特殊专家等场景。
  • zero_expert_num (int):可选参数,表示零专家的数量。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:
      • comm_alg设置为"fullmesh"时,取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID值是[moe_expert_num, moe_expert_num+zero_expert_num)。参数为非0时属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。
      • comm_alg设置为"hierarchy"时,当前版本不支持,传0即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID值是[moe_expert_num, moe_expert_num+zero_expert_num),当comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。
  • copy_expert_num (int):可选参数,表示拷贝专家的数量。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:
      • comm_alg设置为"fullmesh"时,取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID值是[moe_expert_num+zero_expert_num, moe_expert_num+zero_expert_num+copy_expert_num)。参数为非0时属于零计算专家特性,此特性尚在实验阶段,请谨慎使用。
      • comm_alg设置为"hierarchy"时,当前版本不支持,传0即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID值是[moe_expert_num+zero_expert_num, moe_expert_num+zero_expert_num+copy_expert_num),当comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。
  • const_expert_num (int):可选参数,表示常量专家的数量。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:当前版本不支持,传0即可。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的常量专家的ID值是[moe_expert_num+zero_expert_num+copy_expert_num, moe_expert_num+zero_expert_num+copy_expert_num+const_expert_num),当comm_alg设置为"fullmesh_v2"时,当前版本不支持,传0即可。

输出说明

  • expand_x (Tensor):表示本卡收到的token数据,要求为2维张量,shape为(max(tp_world_size, 1) *A, H),A表示在EP通信域可能收到的最大token数,数据类型支持bfloat16float16int8。量化时类型为int8,非量化时与x数据类型保持一致。数据格式为NDND,支持非连续的Tensor。

  • dynamic_scales (Tensor):表示计算得到的动态量化参数。当quant_mode不为0时才有该输出,要求为1维张量,shape为(A,),数据类型支持float,数据格式支持NDND,支持非连续的Tensor。

  • assist_info_for_combine (Tensor):表示给同一专家发送的token个数,要求是一个1维张量,shape为(A * 128, )。数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2assist_info_for_combine输入。

  • expert_token_nums (Tensor):本卡每个专家实际收到的token数量,要求为1维张量,shape为(local_expert_num,),数据类型int64,数据格式支持NDND,支持非连续的Tensor。

  • ep_recv_counts (Tensor):表示EP通信域各卡收到的token数(token数以前缀和的形式表示),要求为1维张量,数据类型int32,数据格式支持NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine_v2ep_send_counts输入。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(moe_expert_num+2*global_bs*K*server_num, ),前moe_expert_num个数表示在EP通信域内,该卡上每个专家收到来自其他各卡的token数(以前缀和的形式表示),2*global_bs*K*server_num用于存储机间和机内通信前,combine可提前做reduce操作的token个数和通信区偏移量,global_bs传入0时此处按照bs*ep_world_size计算。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求shape为(ep_world_size*max(tp_world_size, 1)*local_expert_num, )。
  • tp_recv_counts (Tensor):表示TP通信域各卡收到的token数量。对应torch_npu.npu_moe_distribute_combine_v2tp_send_counts输入。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP通信域,暂无该输出,
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持TP通信域,要求是一个1D Tensor,shape为(tp_world_size, ),数据类型支持int32,数据格式为NDND,支持非连续的Tensor。
  • expand_scales (Tensor):表示expert_scalesx一起进行alltoallv之后的输出。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求是一个1维张量,shape为(A, ),数据类型支持float,数据格式要求为NDND,支持非连续的Tensor。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:暂不支持该输出,返回None。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持静态图模式,npu_moe_distribute_dispatch_v2npu_moe_distribute_combine_v2必须配套使用。

  • 在不同产品型号、不同通信算法或不同版本中,npu_moe_distribute_dispatch_v2的Tensor输出assist_info_for_combineep_recv_countstp_recv_countsexpand_scales中的元素值可能不同,使用时直接将上述Tensor传给npu_moe_distribute_combine_v2对应参数即可,模型其他业务逻辑不应对其存在依赖。

  • 调用接口过程中使用的group_epep_world_sizemoe_expert_numgroup_tptp_world_sizeexpert_shard_typeshared_expert_numshared_expert_rank_numglobal_bs参数取值所有卡需保持一致,group_epep_world_sizegroup_tptp_world_sizeexpert_shard_typeglobal_bs网络不同层中也需保持一致,且和torch_npu.npu_moe_distribute_combine_v2对应参数也保持一致。

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。

  • moe_expert_num + zero_expert_num + copy_expert_num + const_expert_num < MAX_INT32。

  • 参数里Shape使用的变量如下:

    • A:表示本卡接收的最大token数量,取值范围如下

      • 对于共享专家,要满足A=BS*shared_expert_num/shared_expert_rank_num。
      • 对于MoE专家,当global_bs为0时,要满足A>=BS*ep_world_size*min(local_expert_num, K);当global_bs不为0时,要满足A>=global_bs* min(local_expert_num, K)。
    • H:表示hidden size隐藏层大小。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品:H的取值范围如下所示。
        • comm_alg设置为"fullmesh"时,H的取值范围(0, 7168],且保证是32的整数倍。
        • comm_alg设置为"hierarchy"且驱动版本不低于25.0.RC1.1时,H的取值范围(0, 10 * 1024],且保证是32的整数倍。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值为[1024, 8192]。
    • BS:表示batch sequence size,即本卡最终输出的token数量。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品:取值范围为0<BS≤256。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0<BS≤512,当comm_alg为"fullmesh_v2"时,需满足0<BS≤256。
    • K:表示选取topK个专家,取值范围为0<K≤16,同时满足0 < K ≤ moe_expert_num + zero_expert_num + copy_expert_num + const_expert_num,当comm_alg为"fullmesh_v2"时,取值范围为0<K≤12。

    • server_num:表示服务器的节点数,取值只支持2、4、8。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品:仅该场景的shape使用了该变量。
    • local_expert_num:表示本卡专家数量。

      • 对于共享专家卡,local_expert_num为1。
      • 对于MoE专家卡,local_expert_num=moe_expert_num/(ep_world_size-shared_expert_rank_num),当local_expert_num大于1时,不支持TP域通信。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:应满足0 < local_expert_num * ep_world_size ≤ 2048。
  • HCCL通信域缓存区大小:

    调用本接口前需检查通信域缓存区大小取值是否合理,单位MB,不配置时默认为200MB。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品: 该场景支持通过环境变量HCCL_BUFFSIZE配置。

      • comm_alg配置为"": 依照HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE配置选择"fullmesh"或"hierarchy"公式。
      • comm_alg配置为"fullmesh": 设置大小要求>=2*(BS*ep_world_size*min(local_expert_num, K)*H*sizeof(uint16)+2MB)。
      • comm_alg配置为"hierarchy": 设置大小要求=moe_expert_num*BS*(H*sizeof(dtype_x)+4*((K+7)/8*8)*sizeof(uint32))+4MB+100MB,不要求moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品: 该场景不仅支持通过环境变量HCCL_BUFFSIZE配置,还支持通过hccl_buffer_size配置(参考《PyTorch训练模型迁移调优》中“性能调优>性能调优方法>通信优化>优化方法>hccl_buffer_size”章节)。

      • ep通信域内,comm_alg配置为"fullmesh_v1"或"": 设置大小要求 >= 2 * (local_expert_num * max_bs * ep_world_size * Align512(Align32(2 * H) + 64) + (K + shared_expert_num) * max_bs * Align512(2 * H))。
      • ep通信域内,comm_alg配置为"fullmesh_v2": 设置大小要求 >= 2 * (local_expert_num * max_bs * ep_world_size * 480Align512(Align32(2 * H) + 64) + (K + shared_expert_num) * max_bs * Align512(2 * H))。
      • tp通信域内:设置大小要求 >= (A * Align512(Align32(h * 2) + 44) + A * Align512(h * 2)) * 2。
      • 其中 480Align512(x) = ((x+480-1)/480)*512,Align512(x) = ((x+512-1)/512)*512,Align32(x) = ((x+32-1)/32)*32。
  • HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE:

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:该环境变量不再推荐使用,建议comm_alg配置"hierarchy"。
  • 本文公式中的“/”表示整除。

  • 通信域使用约束:

    • 一个模型中的npu_moe_distribute_dispatch_v2npu_moe_distribute_combine_v2算子仅支持相同EP通信域,且该通信域中不允许有其他算子。

    • 一个模型中的npu_moe_distribute_dispatch_v2npu_moe_distribute_combine_v2算子仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:一个通信域内的节点需在一个超节点内,不支持跨超节点。

  • 组网约束:

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:多机场景仅支持交换机组网,不支持双机直连组网。
  • 版本配套约束:

    静态图模式下,从Ascend Extension for PyTorch 8.0.0版本开始,Ascend Extension for PyTorch框架会对静态图中最后一个节点输出结果做Meta推导与inferShape推导的结果强校验。当图中只有一个Dispatch_v2算子,若CANN版本落后于Ascend Extension for PyTorch版本,会出现Shape不匹配报错,建议用户升级CANN版本,详细的版本配套关系参见《Ascend Extension for PyTorch 版本说明》中“相关产品版本配套说明”。

调用示例

  • 单算子模式调用

    import os
    import torch
    import random
    import torch_npu
    import numpy as np
    from torch.multiprocessing import Process
    import torch.distributed as dist
    from torch.distributed import ReduceOp
    
    # 控制模式
    quant_mode = 2  # 2为动态量化
    is_dispatch_scales = True  # 动态量化可选择是否传scales
    input_dtype = torch.bfloat16  # 输出dtype
    server_num = 1
    server_index = 0
    port = 50001
    master_ip = '127.0.0.1'
    dev_num = 16
    world_size = server_num * dev_num
    rank_per_dev = int(world_size / server_num)  # 每个host有几个die
    shared_expert_rank_num = 0  # 共享专家数
    moe_expert_num = 32  # moe专家数
    bs = 8  # token数量
    h = 7168  # 每个token的长度
    k = 8
    random_seed = 0
    tp_world_size = 1
    ep_world_size = int(world_size / tp_world_size)
    moe_rank_num = ep_world_size - shared_expert_rank_num
    local_moe_expert_num = moe_expert_num // moe_rank_num
    globalBS = bs * ep_world_size
    is_shared = (shared_expert_rank_num > 0)
    is_quant = (quant_mode > 0)
    zero_expert_num = 1
    copy_expert_num = 1
    const_expert_num = 1
    
    
    def gen_const_expert_alpha_1():
        const_expert_alpha_1 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_alpha_1
    
    
    def gen_const_expert_alpha_2():
        const_expert_alpha_2 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_alpha_2
    
    
    def gen_const_expert_v():
        const_expert_v = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_v
    
    
    def get_new_group(rank):
        for i in range(tp_world_size):
            # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]
            ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)]
            ep_group = dist.new_group(backend="hccl", ranks=ep_ranks)
            if rank in ep_ranks:
                ep_group_t = ep_group
                print(f"rank:{rank} ep_ranks:{ep_ranks}")
        for i in range(ep_world_size):
            # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
            tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)]
            tp_group = dist.new_group(backend="hccl", ranks=tp_ranks)
            if rank in tp_ranks:
                tp_group_t = tp_group
                print(f"rank:{rank} tp_ranks:{tp_ranks}")
        return ep_group_t, tp_group_t
    
    
    def get_hcomm_info(rank, comm_group):
        if torch.__version__ > '2.0.1':
            hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = comm_group.get_hccl_comm_name(rank)
        return hcomm_info
    
    
    def warm_up_dispatch(rank, group_ep, group_tp):
        x_warm_up = torch.empty(size=[1, h], dtype=input_dtype).uniform_(-1024, 1024).to(input_dtype).npu()
        expert_ids_warm_up = torch.arange(0, k, dtype=torch.int32).unsqueeze(0).npu()
        dispatch_kwargs_before = get_dispatch_kwargs_warmup(
            x_warm_up=x_warm_up,
            expert_ids_warm_up=expert_ids_warm_up,
            group_ep=group_ep,
            group_tp=group_tp,
            ep_rank_id=rank // tp_world_size,
            tp_rank_id=rank % tp_world_size,
        )
        (
            expand_x, dynamic_scales, expand_idx,
            expert_token_nums, ep_recv_counts, tp_recv_counts, _
        ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs_before)
        return expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts
    
    
    def get_dispatch_kwargs_warmup(
        x_warm_up, expert_ids_warm_up, group_ep, group_tp, ep_rank_id, tp_rank_id,
    ):
        x_warm_up = x_warm_up.to(input_dtype).npu()
        expert_ids_warm_up = expert_ids_warm_up.to(torch.int32).npu()
        return {
            'x': x_warm_up,
            'expert_ids': expert_ids_warm_up,
            'x_active_mask': None,
            'group_ep': group_ep,
            'group_tp': group_tp,
            'ep_rank_id': ep_rank_id,
            'tp_rank_id': tp_rank_id,
            'ep_world_size': ep_world_size,
            'tp_world_size': tp_world_size,
            'expert_shard_type': 0,
            'shared_expert_num': 0,
            'shared_expert_rank_num': shared_expert_rank_num,
            'moe_expert_num': moe_expert_num,
            'scales': None,
            'quant_mode': 2,
            'global_bs': 16,
        }
    
    
    def run_npu_process(rank):
        torch_npu.npu.set_device(rank)
        rank = rank + 16 * server_index
        dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}')
        ep_group, tp_group = get_new_group(rank)
        ep_hcomm_info = get_hcomm_info(rank, ep_group)
        tp_hcomm_info = get_hcomm_info(rank, tp_group)
    
        # 创建输入tensor
        x = torch.randn(bs, h, dtype=input_dtype).npu()
        expert_ids = torch.tensor([[5, 7, 17, 4, 2, 6, 11, 16],
                                [10, 12, 13, 15, 19, 4, 18, 1],
                                [19, 33, 1, 17, 9, 5, 0, 32],
                                [19, 11, 17, 0, 10, 5, 7, 9],
                                [10, 16, 11, 17, 33, 8, 9, 3],
                                [12, 19, 5, 7, 1, 3, 18, 16],
                                [11, 9, 13, 16, 12, 33, 17, 14],
                                [16, 4, 9, 5, 0, 10, 11, 17]], dtype=torch.int32).npu()
        expert_scales = torch.randn(bs, k, dtype=torch.float32).npu()
    
        scales_shape = (1 + moe_expert_num, h) if shared_expert_rank_num else (moe_expert_num, h)
        if is_dispatch_scales:
            scales = torch.randn(scales_shape, dtype=torch.float32).npu()
        else:
            scales = None
    
        const_expert_alpha_1 = gen_const_expert_alpha_1().npu()
        const_expert_alpha_2 = gen_const_expert_alpha_2().npu()
        const_expert_v = gen_const_expert_v().npu()
    
        out = warm_up_dispatch(rank, ep_hcomm_info, tp_hcomm_info)
    
        expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales = torch_npu.npu_moe_distribute_dispatch_v2(
            x=x,
            expert_ids=expert_ids,
            group_ep=ep_hcomm_info,
            group_tp=tp_hcomm_info,
            ep_world_size=ep_world_size,
            tp_world_size=tp_world_size,
            ep_rank_id=rank // tp_world_size,
            tp_rank_id=rank % tp_world_size,
            expert_shard_type=0,
            shared_expert_rank_num=shared_expert_rank_num,
            moe_expert_num=moe_expert_num,
            scales=scales,
            quant_mode=quant_mode,
            global_bs=globalBS,
            zero_expert_num=zero_expert_num,
            copy_expert_num=copy_expert_num,
            const_expert_num=const_expert_num)
    
        if is_quant:
            expand_x = expand_x.to(input_dtype)
    
        x = torch_npu.npu_moe_distribute_combine_v2(expand_x=expand_x,
                                                expert_ids=expert_ids,
                                                assist_info_for_combine=assist_info_for_combine,
                                                ep_send_counts=ep_recv_counts,
                                                tp_send_counts=tp_recv_counts,
                                                expert_scales=expert_scales,
                                                group_ep=ep_hcomm_info,
                                                group_tp=tp_hcomm_info,
                                                ep_world_size=ep_world_size,
                                                tp_world_size=tp_world_size,
                                                ep_rank_id=rank // tp_world_size,
                                                tp_rank_id=rank % tp_world_size,
                                                expert_shard_type=0,
                                                shared_expert_rank_num=shared_expert_rank_num,
                                                moe_expert_num=moe_expert_num,
                                                global_bs=globalBS,
                                                ori_x=x,
                                                const_expert_alpha_1=const_expert_alpha_1,
                                                const_expert_alpha_2=const_expert_alpha_2,
                                                const_expert_v=const_expert_v,
                                                zero_expert_num=zero_expert_num,
                                                copy_expert_num=copy_expert_num,
                                                const_expert_num=const_expert_num)
        print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n')
    
    
    if __name__ == "__main__":
        print(f"bs={bs}")
        print(f"global_bs={globalBS}")
        print(f"shared_expert_rank_num={shared_expert_rank_num}")
        print(f"moe_expert_num={moe_expert_num}")
        print(f"k={k}")
        print(f"quant_mode={quant_mode}", flush=True)
        print(f"local_moe_expert_num={local_moe_expert_num}", flush=True)
        print(f"tp_world_size={tp_world_size}", flush=True)
        print(f"ep_world_size={ep_world_size}", flush=True)
    
        if tp_world_size != 1 and local_moe_expert_num > 1:
            print("unSupported tp = 2 and local moe > 1")
            exit(0)
        if shared_expert_rank_num > ep_world_size:
            print("shared_expert_rank_num 不能大于 ep_world_size")
            exit(0)
        if shared_expert_rank_num > 0 and ep_world_size % shared_expert_rank_num != 0:
            print("ep_world_size 必须是 shared_expert_rank_num的整数倍")
            exit(0)
        if moe_expert_num % moe_rank_num != 0:
            print("moe_expert_num 必须是 moe_rank_num 的整数倍")
            exit(0)
        p_list = []
        for rank in range(rank_per_dev):
            p = Process(target=run_npu_process, args=(rank,))
            p_list.append(p)
    
        for p in p_list:
            p.start()
        for p in p_list:
            p.join()
    
        print("run npu success.")
    
  • 图模式调用

    # 仅支持静态图
    import os
    import torch
    import random
    import torch_npu
    import torchair
    import numpy as np
    from torch.multiprocessing import Process
    import torch.distributed as dist
    from torch.distributed import ReduceOp
    import time
    
    
    # 控制模式
    quant_mode = 2  # 2为动态量化
    is_dispatch_scales = True  # 动态量化可选择是否传scales
    input_dtype = torch.bfloat16  # 输出dtype
    server_num = 1
    server_index = 0
    port = 50001
    master_ip = '127.0.0.1'
    dev_num = 16
    world_size = server_num * dev_num
    rank_per_dev = int(world_size / server_num)  # 每个host有几个die
    shared_expert_rank_num = 0  # 共享专家数
    moe_expert_num = 32  # moe专家数
    bs = 8  # token数量
    h = 7168  # 每个token的长度
    k = 8
    random_seed = 0
    tp_world_size = 1
    ep_world_size = int(world_size / tp_world_size)
    moe_rank_num = ep_world_size - shared_expert_rank_num
    local_moe_expert_num = moe_expert_num // moe_rank_num
    globalBS = bs * ep_world_size
    is_shared = (shared_expert_rank_num > 0)
    is_quant = (quant_mode > 0)
    
    zero_expert_num = 1
    copy_expert_num = 1
    const_expert_num = 1
    
    
    class MOE_DISTRIBUTE_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, x, expert_ids, group_ep, group_tp, ep_world_size, tp_world_size,
                    ep_rank_id, tp_rank_id, expert_shard_type, shared_expert_rank_num, moe_expert_num,
                    scales, quant_mode, global_bs, expert_scales, elastic_info, const_expert_alpha_1, const_expert_alpha_2, const_expert_v, zero_expert_num, copy_expert_num, const_expert_num):
            output_dispatch_npu = torch_npu.npu_moe_distribute_dispatch_v2(
                x=x,
                expert_ids=expert_ids,
                group_ep=group_ep,
                group_tp=group_tp,
                ep_world_size=ep_world_size,
                tp_world_size=tp_world_size,
                ep_rank_id=ep_rank_id,
                tp_rank_id=tp_rank_id,
                expert_shard_type=expert_shard_type,
                shared_expert_rank_num=shared_expert_rank_num,
                moe_expert_num=moe_expert_num,
                scales=scales,
                quant_mode=quant_mode,
                global_bs=global_bs,
                elastic_info=elastic_info,
                zero_expert_num=zero_expert_num,
                copy_expert_num=copy_expert_num,
                const_expert_num=const_expert_num
            )
            expand_x_npu, _, assist_info_for_combine_npu, _, ep_recv_counts_npu, tp_recv_counts_npu, expand_scales = output_dispatch_npu
            if expand_x_npu.dtype == torch.int8:
                expand_x_npu = expand_x_npu.to(input_dtype)
    
            output_combine_npu = torch_npu.npu_moe_distribute_combine_v2(
                expand_x=expand_x_npu,
                expert_ids=expert_ids,
                assist_info_for_combine=assist_info_for_combine_npu,
                ep_send_counts=ep_recv_counts_npu,
                tp_send_counts=tp_recv_counts_npu,
                expert_scales=expert_scales,
                group_ep=group_ep,
                group_tp=group_tp,
                ep_world_size=ep_world_size,
                tp_world_size=tp_world_size,
                ep_rank_id=ep_rank_id,
                tp_rank_id=tp_rank_id,
                expert_shard_type=expert_shard_type,
                shared_expert_rank_num=shared_expert_rank_num,
                moe_expert_num=moe_expert_num,
                global_bs=global_bs,
                elastic_info=elastic_info,
                ori_x=x,
                const_expert_alpha_1=const_expert_alpha_1,
                const_expert_alpha_2=const_expert_alpha_2,
                const_expert_v=const_expert_v,
                zero_expert_num=zero_expert_num,
                copy_expert_num=copy_expert_num,
                const_expert_num=const_expert_num
            )
            x = output_combine_npu
            x_combine_res = output_combine_npu
            return [x_combine_res, output_combine_npu]
    
    
    def gen_const_expert_alpha_1():
        const_expert_alpha_1 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_alpha_1
    
    
    def gen_const_expert_alpha_2():
        const_expert_alpha_2 = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_alpha_2
    
    
    def gen_const_expert_v():
        const_expert_v = torch.empty(size=[const_expert_num, h], dtype=input_dtype).uniform_(-1, 1)
        return const_expert_v
    
    
    def get_new_group(rank):
        for i in range(tp_world_size):
            ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)]
            ep_group = dist.new_group(backend="hccl", ranks=ep_ranks)
            if rank in ep_ranks:
                ep_group_t = ep_group
                print(f"rank:{rank} ep_ranks:{ep_ranks}")
        for i in range(ep_world_size):
            tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)]
            tp_group = dist.new_group(backend="hccl", ranks=tp_ranks)
            if rank in tp_ranks:
                tp_group_t = tp_group
                print(f"rank:{rank} tp_ranks:{tp_ranks}")
        return ep_group_t, tp_group_t
    
    
    def get_hcomm_info(rank, comm_group):
        if torch.__version__ > '2.0.1':
            hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = comm_group.get_hccl_comm_name(rank)
        return hcomm_info
    
    
    def warm_up_dispatch(rank, group_ep, group_tp):
        x_warm_up = torch.empty(size=[1, h], dtype=input_dtype).uniform_(-1024, 1024).to(input_dtype).npu()
        expert_ids_warm_up = torch.arange(0, k, dtype=torch.int32).unsqueeze(0).npu()
    
        dispatch_kwargs_before = get_dispatch_kwargs_warmup(
            x_warm_up=x_warm_up,
            expert_ids_warm_up=expert_ids_warm_up,
            group_ep=group_ep,
            group_tp=group_tp,
            ep_rank_id=rank//tp_world_size,
            tp_rank_id=rank%tp_world_size,
        )
    
        (
            expand_x, dynamic_scales, expand_idx,
            expert_token_nums, ep_recv_counts, tp_recv_counts, _
        ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs_before)
        return expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts
    
    
    def get_dispatch_kwargs_warmup(
        x_warm_up, expert_ids_warm_up, group_ep, group_tp, ep_rank_id, tp_rank_id,
    ):
        x_warm_up = x_warm_up.to(input_dtype).npu()
        expert_ids_warm_up = expert_ids_warm_up.to(torch.int32).npu()
    
        return {
            'x': x_warm_up,
            'expert_ids': expert_ids_warm_up,
            'x_active_mask': None,
            'group_ep': group_ep,
            'group_tp': group_tp,
            'ep_rank_id': ep_rank_id,
            'tp_rank_id': tp_rank_id,
            'ep_world_size': ep_world_size,
            'tp_world_size': tp_world_size,
            'expert_shard_type': 0,
            'shared_expert_num': 0,
            'shared_expert_rank_num': shared_expert_rank_num,
            'moe_expert_num': moe_expert_num,
            'scales': None,
            'quant_mode': 2,
            'global_bs': 16,
        }
    
    
    def run_npu_process(rank):
        torch_npu.npu.set_device(rank)
        rank = rank + 16 * server_index
        dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=world_size,
            init_method=f'tcp://{master_ip}:{port}'
        )
        ep_group, tp_group = get_new_group(rank)
        ep_hcomm_info = get_hcomm_info(rank, ep_group)
        tp_hcomm_info = get_hcomm_info(rank, tp_group)
    
        # 创建输入tensor
        x = torch.randn(bs, h, dtype=input_dtype).npu()
        expert_ids = torch.tensor([
            [0, 8, 4, 1, 6, 12, 14, 17],
            [14, 10, 7, 3, 0, 12, 11, 17],
            [12, 0, 5, 11, 19, 4, 6, 18],
            [17, 3, 4, 10, 18, 0, 1, 2],
            [13, 16, 9, 10, 15, 6, 7, 14],
            [17, 15, 14, 8, 16, 18, 3, 12],
            [4, 12, 2, 17, 15, 3, 9, 10],
            [16, 7, 12, 9, 18, 3, 19, 17]
        ], dtype=torch.int32).npu()
    
        expert_scales = torch.randn(bs, k, dtype=torch.float32).npu()
        scales_shape = (1 + moe_expert_num, h) if shared_expert_rank_num else (moe_expert_num, h)
        if is_dispatch_scales:
            scales = torch.randn(scales_shape, dtype=torch.float32).npu()
        else:
            scales = None
    
        elastic_info = None
        available_ranks = [1, 2, 3, 5, 7, 9, 10, 11, 13, 14]
        const_expert_alpha_1 = gen_const_expert_alpha_1().npu()
        const_expert_alpha_2 = gen_const_expert_alpha_2().npu()
        const_expert_v = gen_const_expert_v().npu()
        out = warm_up_dispatch(rank, ep_hcomm_info, tp_hcomm_info)
    
        model = MOE_DISTRIBUTE_GRAPH_Model()
        model = model.npu()
        npu_backend = torchair.get_npu_backend()
        model = torch.compile(model, backend=npu_backend, dynamic=False)
        output = model.forward(
            x, expert_ids, ep_hcomm_info, tp_hcomm_info, ep_world_size, tp_world_size,
            rank // tp_world_size, rank % tp_world_size, 0, shared_expert_rank_num, moe_expert_num, scales,
            quant_mode, globalBS, expert_scales, elastic_info, const_expert_alpha_1, const_expert_alpha_2, const_expert_v,
            zero_expert_num, copy_expert_num, const_expert_num
        )
        torch.npu.synchronize()
        print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n')
    
        time.sleep(10)
    
    
    if __name__ == "__main__":
        print(f"bs={bs}")
        print(f"global_bs={globalBS}")
        print(f"shared_expert_rank_num={shared_expert_rank_num}")
        print(f"moe_expert_num={moe_expert_num}")
        print(f"k={k}")
        print(f"quant_mode={quant_mode}", flush=True)
        print(f"local_moe_expert_num={local_moe_expert_num}", flush=True)
        print(f"tp_world_size={tp_world_size}", flush=True)
        print(f"ep_world_size={ep_world_size}", flush=True)
    
        if tp_world_size != 1 and local_moe_expert_num > 1:
            print("unSupported tp = 2 and local moe > 1")
            exit(0)
    
        if shared_expert_rank_num > ep_world_size:
            print("shared_expert_rank_num 不能大于 ep_world_size")
            exit(0)
    
        if shared_expert_rank_num > 0 and ep_world_size % shared_expert_rank_num != 0:
            print("ep_world_size 必须是 shared_expert_rank_num的整数倍")
            exit(0)
    
        if moe_expert_num % moe_rank_num != 0:
            print("moe_expert_num 必须是 moe_rank_num 的整数倍")
            exit(0)
    
        p_list = []
        for rank in range(rank_per_dev):
            p = Process(target=run_npu_process, args=(rank,))
            p_list.append(p)
    
        for p in p_list:
            p.start()
        for p in p_list:
            p.join()
    
        print("run npu success.")