torch_npu.npu_moe_distribute_combine
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
API功能:先进行reduce_scatterv通信,再进行alltoallv通信,最后将接收的数据整合(乘权重再相加)。需与torch_npu.npu_moe_distribute_dispatch配套使用,相当于按npu_moe_distribute_dispatch算子收集数据的路径原路返回。
-
计算公式:
rs_out=ReduceScatterV(expend_x)ata_out=AlltoAllv(rs_out)x=Sum(expert_scales∗ata_out+expert_scales∗shared_expert_x)rs\_out = ReduceScatterV(expend\_x)\\ ata\_out = AlltoAllv(rs\_out)\\ x = Sum(expert\_scales * ata\_out + expert\_scales * shared\_expert\_x)
函数原型
torch_npu.npu_moe_distribute_combine(expand_x, expert_ids, expand_idx, ep_send_counts, expert_scales, group_ep, ep_world_size, ep_rank_id, moe_expert_num, *, tp_send_counts=None, x_active_mask=None, activation_scale=None, weight_scale=None, group_list=None, expand_scales=None, shared_expert_x=None, group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, global_bs=0, out_dtype=0, comm_quant_mode=0, group_list_type=0) -> Tensor
参数说明
-
expand_x (
Tensor):必选参数。根据expert_ids进行扩展过的token特征,要求为2维张量,shape为(max(tp_world_size, 1) *A, H),数据类型支持bfloat16、float16,数据格式为NDND,支持非连续的Tensor。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持共享专家场景。
-
expert_ids (
Tensor):必选参数。每个token的topK个专家索引,要求为2维张量,shape为(BS, K)。数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_dispatch的expert_ids输入,张量里value取值范围为[0, moe_expert_num),且同一行中的K个value不能重复。 -
expand_idx (
Tensor):必选参数。表示给同一专家发送的token个数,要求为1维张量。数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_dispatch的expand_idx输出。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(BS * K,)。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求shape为(BS * K,)。
-
ep_send_counts (
Tensor):必选参数。示本卡每个专家发给EP(Expert Parallelism)域每个卡的token数(token数以前缀和的形式表示),要求为1维张量。数据类型支持int32,数据格式为NDND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_dispatch的ep_recv_counts输出。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(moe_expert_num+2*global_bs*K*server_num,),前
moe_expert_num个数表示在EP通信域内,该卡上每个专家收到来自其他各卡的token数(以前缀和的形式表示),2*global_bs*K*server_num用于存储机间和机内通信前,combine可提前做reduce操作的token个数和通信区偏移量,global_bs传入0时此处按照bs*ep_world_size计算。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求shape为(ep_world_size*max(tp_world_size, 1)*local_expert_num,)。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:要求shape为(moe_expert_num+2*global_bs*K*server_num,),前
-
expert_scales (
Tensor):必选参数。表示每个token的topK个专家的权重,要求为2维张量,shape为(BS, K),其中共享专家不需要乘权重系数,直接相加即可。数据类型支持float,数据格式为NDND,支持非连续的Tensor。 -
group_ep (
str):必选参数。EP通信域名称,专家并行的通信域。字符串长度范围为[1, 128)。Atlas A3 训练系列产品/Atlas A3 推理系列产品时不能和group_tp相同。 -
ep_world_size (
int):必选参数,EP通信域size。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:取值支持16、32、64。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值支持8、16、32、64、128、144、256、288。
-
ep_rank_id (
int):必选参数,EP通信域本卡ID,取值范围[0, ep_world_size),同一个EP通信域中各卡的ep_rank_id不重复。 -
moe_expert_num (
int):必选参数,MoE专家数量,取值范围[1, 512],并且满足以下条件:moe_expert_num%(ep_world_size - shared_expert_rank_num)=0。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:还需满足moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
-
*:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
-
tp_send_counts (
Tensor):可选参数,表示本卡每个专家发给TP(Tensor Parallelism)通信域每个卡的数据量。对应torch_npu.npu_moe_distribute_dispatch的tp_recv_counts输出。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP通信域,使用默认输入None。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持TP通信域,要求为一个1维张量,shape为(tp_world_size,),数据类型支持
int32,数据格式为NDND,支持非连续的Tensor。
-
x_active_mask (
Tensor):预留参数,暂未使用,使用默认值即可。 -
activation_scale (
Tensor):预留参数,暂未使用,使用默认值即可。 -
weight_scale (
Tensor):预留参数,暂未使用,使用默认值即可。 -
group_list (
Tensor):预留参数,暂未使用,使用默认值即可。 -
expand_scales (
Tensor):对应torch_npu.npu_moe_distribute_dispatch的expand_scales输出。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:必选参数,要求为1维张量,shape为(A,),数据类型支持
float,数据格式为NDND,支持非连续的Tensor。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:必选参数,要求为1维张量,shape为(A,),数据类型支持
-
shared_expert_x (
Tensor):预留参数,暂未使用,使用默认值即可。 -
group_tp (
str):可选参数,TP通信域名称,数据并行的通信域。有TP域通信才需要传参,若无TP域通信,使用默认值""即可。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:eager模式使用默认值即可,图模式传入与
group_ep相同。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:字符串长度范围为[1, 128),不能和
group_ep相同。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:eager模式使用默认值即可,图模式传入与
-
tp_world_size (
int):可选参数,TP通信域size。有TP域通信才需要传参。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 2],0和1表示无TP域通信,2表示有TP域通信。
-
tp_rank_id (
int):可选参数,TP通信域本卡ID。有TP域通信才需要传参。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持TP域通信,使用默认值0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,取值范围[0, 1],同一个TP通信域中各卡的
tp_rank_id不重复。无TP域通信时,传0即可。
-
expert_shard_type (
int):表示共享专家卡排布类型。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前仅支持0,表示共享专家卡排在MoE专家卡前面。
-
shared_expert_num (
int):表示共享专家数量,一个共享专家可以复制部署到多个卡上。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:暂不支持该参数,使用默认值即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:仅支持1,默认值为1。
-
shared_expert_rank_num (
int):可选参数,表示共享专家卡数量。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持共享专家,传0即可。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, ep_world_size)。取0表示无共享专家,不取0时需满足ep_world_size%shared_expert_rank_num=0。
-
global_bs (
int):可选参数,表示EP域全局的BS(batch size)大小。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:当每个rank的BS不同时,支持传入max_bs*ep_world_size或者256*ep_world_size,其中max_bs表示单rank BS最大值,建议按max_bs*ep_world_size传入,固定按256*ep_world_size传入,在后续版本BS大于256的场景下会无法支持;当每个rank的BS相同时,支持取值0或BS*ep_world_size。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:当每个rank的BS不同时,支持传入max_bs*ep_world_size,其中max_bs表示单rank BS最大值;当每个rank的BS相同时,支持取值0或BS*ep_world_size。
-
out_dtype (
int):预留参数,暂未使用,使用默认值即可。 -
comm_quant_mode (
int):表示通信量化类型。- Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持取0和2。0表示通信时不量化,2表示通信时进行
int8量化。仅当HCCL_INTRA_PCIE_ENABLE=1且HCCL_INTRA_ROCE_ENABLE=0且驱动版本不低于25.0.RC1.1时才支持取2。 - Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持取0和2。0表示通信时不量化,2表示通信时进行
int8量化。当且仅当tp_world_size不等于2时,可以使能int8量化。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:支持取0和2。0表示通信时不量化,2表示通信时进行
-
group_list_type (
int):预留参数,暂未使用,使用默认值即可。
返回值说明
Tensor
表示处理后的token,要求为2维张量,shape为(BS, H),数据类型支持bfloat16、float16,类型与输入expand_x保持一致,数据格式为NDND,不支持非连续的Tensor。
约束说明
-
该接口支持推理场景下使用。
-
该接口支持静态图模式,
npu_moe_distribute_dispatch和npu_moe_distribute_combine必须配套使用。 -
在不同产品型号、不同通信算法或不同版本中,
npu_moe_distribute_dispatch的Tensor输出expand_idx、ep_recv_counts、tp_recv_counts、expand_scales中的元素值可能不同,使用时直接将上述Tensor传给npu_moe_distribute_combine对应参数即可,模型其他业务逻辑不应对其存在依赖。 -
调用接口过程中使用的
group_ep、ep_world_size、moe_expert_num、group_tp、tp_world_size、expert_shard_type、shared_expert_num、shared_expert_rank_num、global_bs参数取值所有卡需保持一致,group_ep、ep_world_size、moe_expert_num、group_tp、tp_world_size、expert_shard_type、global_bs网络中不同层中也需保持一致,且和torch_npu.npu_moe_distribute_dispatch对应参数也保持一致。 -
Atlas A3 训练系列产品/Atlas A3 推理系列产品:该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。
-
参数里Shape使用的变量如下:
-
A:表示本卡接收的最大token数量,取值范围如下:
- 对于共享专家,当
global_bs为0时,要满足A=BS*shared_expert_num/shared_expert_rank_num;当global_bs非0时,要满足A=global_bs*shared_expert_num/shared_expert_rank_num。 - 对于MoE专家,当
global_bs为0时,要满足A>=BS*ep_world_size*min(local_expert_num, K);当global_bs非0时,要满足A>=global_bs* min(local_expert_num, K)。
- 对于共享专家,当
-
H:表示hidden size隐藏层大小。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:取值范围(0, 7168],且保证是32的整数倍。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:仅支持 7168。
-
BS:表示待发送的token数量。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:取值范围为0<BS≤256。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0<BS≤512。
-
K:表示选取topK个专家,需满足0<K≤moe_expert_num。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:保证取值范围为0<K≤16。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:保证取值范围为0<K≤8。
-
server_num:表示服务器的节点数,取值只支持2、4、8。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:仅该场景的shape使用了该变量。
-
local_expert_num:表示本卡专家数量。
- 对于共享专家卡,local_expert_num=1
- 对于MoE专家卡,local_expert_num=moe_expert_num/(ep_world_size-shared_expert_rank_num),当local_expert_num>1时,不支持TP域通信。
-
-
HCCL通信域缓存区大小:
调用本接口前需检查通信域缓存区大小取值是否合理,单位MB,不配置时默认为200MB。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
该场景支持通过环境变量HCCL_BUFFSIZE配置。
- 设置大小要求>=2*(BS*ep_world_size*min(local_expert_num, K)*H*sizeof(uint16)+2MB)。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:
该场景不仅支持通过环境变量HCCL_BUFFSIZE配置,还支持通过hccl_buffer_size配置(参考《PyTorch训练模型迁移调优》中“性能调优>性能调优方法>通信优化>优化方法>hccl_buffer_size”章节)。
- ep通信域内:设置大小要求>=2且满足1024^2*(HCCL_BUFFSIZE-2)/2>=BS*2*(H+128)*(ep_world_size*local_expert_num+K+1),local_expert_num需使用MoE专家卡的本卡专家数。
- tp通信域内:设置大小要求>=A * (H * 2 + 128) * 2。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
该场景支持通过环境变量HCCL_BUFFSIZE配置。
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品:配置环境变量HCCL_INTRA_PCIE_ENABLE=1和HCCL_INTRA_ROCE_ENABLE=0可以减少跨机通信数据量,提升算子性能。此时要求HCCL_BUFFSIZE>=moe_expert_num*BS*(H*sizeof(dtype_x)+4*((K+7)/8*8)*sizeof(uint32))+4MB+100MB。并且,对于入参moe_expert_num,只要求moe_expert_num%(ep_world_size - shared_expert_rank_num)=0,不要求moe_expert_num/(ep_world_size - shared_expert_rank_num) <= 24。
-
本文公式中的“/”表示整除。
-
通信域使用约束:
-
一个模型中的
npu_moe_distribute_dispatch和npu_moe_distribute_combine算子仅支持相同EP通信域,且该通信域中不允许有其他算子。 -
一个模型中的
npu_moe_distribute_dispatch和npu_moe_distribute_combine算子仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。 -
Atlas A3 训练系列产品/Atlas A3 推理系列产品:一个通信域内的节点需在一个超节点内,不支持跨超节点。
-
-
组网约束:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:多机场景仅支持交换机组网,不支持双机直连组网。
调用示例
-
单算子模式调用
import os import torch import random import torch_npu import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die sharedExpertRankNum = 2 # 共享专家数 moeExpertNum = 14 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - sharedExpertRankNum local_moe_expert_num = moeExpertNum // moe_rank_num globalBS = bs * ep_world_size is_shared = (sharedExpertRankNum > 0) is_quant = (quant_mode > 0) def gen_unique_topk_array(low, high, bs, k): array = [] for i in range(bs): top_idx = list(np.arange(low, high, dtype=np.int32)) random.shuffle(top_idx) array.append(top_idx[0:k]) return np.array(array) def get_new_group(rank): for i in range(tp_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]] ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}') ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32) expert_ids = torch.from_numpy(expert_ids).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales = torch_npu.npu_moe_distribute_dispatch( x=x, expert_ids=expert_ids, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=sharedExpertRankNum, moe_expert_num=moeExpertNum, scales=scales, quant_mode=quant_mode, global_bs=globalBS) if is_quant: expand_x = expand_x.to(input_dtype) x = torch_npu.npu_moe_distribute_combine(expand_x=expand_x, expert_ids=expert_ids, expand_idx=expand_idx, ep_send_counts=ep_recv_counts, tp_send_counts=tp_recv_counts, expert_scales=expert_scales, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=sharedExpertRankNum, moe_expert_num=moeExpertNum, global_bs=globalBS) print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={sharedExpertRankNum}") print(f"moe_expert_num={moeExpertNum}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if sharedExpertRankNum > ep_world_size: print("sharedExpertRankNum 不能大于 ep_world_size") exit(0) if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0: print("ep_world_size 必须是 sharedExpertRankNum的整数倍") exit(0) if moeExpertNum % moe_rank_num != 0: print("moeExpertNum 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.") -
图模式调用
# 仅支持静态图 import os import torch import random import torch_npu import torchair import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die sharedExpertRankNum = 2 # 共享专家数 moeExpertNum = 14 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - sharedExpertRankNum local_moe_expert_num = moeExpertNum // moe_rank_num globalBS = bs * ep_world_size is_shared = (sharedExpertRankNum > 0) is_quant = (quant_mode > 0) class MOE_DISTRIBUTE_GRAPH_Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, expert_ids, group_ep, group_tp, ep_world_size, tp_world_size, ep_rank_id, tp_rank_id, expert_shard_type, shared_expert_rank_num, moe_expert_num, scales, quant_mode, global_bs, expert_scales): output_dispatch_npu = torch_npu.npu_moe_distribute_dispatch(x=x, expert_ids=expert_ids, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, scales=scales, quant_mode=quant_mode, global_bs=global_bs) expand_x_npu, _, expand_idx_npu, _, ep_recv_counts_npu, tp_recv_counts_npu, expand_scales = output_dispatch_npu if expand_x_npu.dtype == torch.int8: expand_x_npu = expand_x_npu.to(input_dtype) output_combine_npu = torch_npu.npu_moe_distribute_combine(expand_x=expand_x_npu, expert_ids=expert_ids, expand_idx=expand_idx_npu, ep_send_counts=ep_recv_counts_npu, tp_send_counts=tp_recv_counts_npu, expert_scales=expert_scales, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, global_bs=global_bs) x = output_combine_npu x_combine_res = output_combine_npu return [x_combine_res, output_combine_npu] def gen_unique_topk_array(low, high, bs, k): array = [] for i in range(bs): top_idx = list(np.arange(low, high, dtype=np.int32)) random.shuffle(top_idx) array.append(top_idx[0:k]) return np.array(array) def get_new_group(rank): for i in range(tp_world_size): ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}') ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32) expert_ids = torch.from_numpy(expert_ids).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None model = MOE_DISTRIBUTE_GRAPH_Model() model = model.npu() npu_backend = torchair.get_npu_backend() model = torch.compile(model, backend=npu_backend, dynamic=False) output = model.forward(x, expert_ids, ep_hcomm_info, tp_hcomm_info, ep_world_size, tp_world_size, rank // tp_world_size,rank % tp_world_size, 0, sharedExpertRankNum, moeExpertNum, scales, quant_mode, globalBS, expert_scales) torch.npu.synchronize() print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={sharedExpertRankNum}") print(f"moe_expert_num={moeExpertNum}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if sharedExpertRankNum > ep_world_size: print("sharedExpertRankNum 不能大于 ep_world_size") exit(0) if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0: print("ep_world_size 必须是 sharedExpertRankNum的整数倍") exit(0) if moeExpertNum % moe_rank_num != 0: print("moeExpertNum 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.")