import os
import unittest
import numpy as np
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
quant_mode = 2
is_dispatch_scales = True
input_dtype = torch.bfloat16
sharedExpertRankNum = 0
moeExpertNum = 16
bs = 8
h = 1024
k = 8
random_seed = 0
ep_world_size = 16
moe_rank_num = ep_world_size - sharedExpertRankNum
local_moe_expert_num = moeExpertNum // moe_rank_num
globalBS = bs * ep_world_size
is_shared = (sharedExpertRankNum > 0)
is_quant = (quant_mode > 0)
def gen_unique_topk_array(low, high):
array = []
for _ in range(bs):
top_idx = list(np.arange(low, high, dtype=np.int32))
np.random.shuffle(top_idx)
array.append(top_idx[0:k])
return np.array(array)
def test_npu_moe_distribute_dispatch_combine_urma_(rank, c2p, p2c):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + "127.0.0.1" + ':' + '50001'
dist.init_process_group(backend='hccl', world_size=ep_world_size, rank=rank, init_method=init_method)
ep_group = dist.new_group(backend="hccl", ranks=range(ep_world_size))
ep_hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
x = torch.randn(bs, h, dtype=input_dtype).npu()
expert_ids = gen_unique_topk_array(0, moeExpertNum).astype(np.int32)
expert_ids = torch.from_numpy(expert_ids).npu()
expert_scales = torch.randn(bs, k, dtype=torch.float32).npu()
scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum == 0 else (moeExpertNum, h)
if is_dispatch_scales:
scales = torch.randn(scales_shape, dtype=torch.float32).npu()
else:
scales = None
torch.npu.synchronize()
y, expand_idx, comm_cmd_info = torch_npu.npu_moe_distribute_dispatch_setup(
x=x,
expert_ids=expert_ids,
group_ep=ep_hcomm_info,
ep_world_size=ep_world_size,
ep_rank_id=rank,
moe_expert_num=moeExpertNum,
expert_shard_type=0,
shared_expert_rank_num=sharedExpertRankNum,
scales=scales,
quant_mode=quant_mode,
global_bs=globalBS)
torch.npu.synchronize()
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums = torch_npu.npu_moe_distribute_dispatch_teardown(
x=x,
y=y,
expert_ids=expert_ids,
comm_cmd_info=comm_cmd_info,
group_ep=ep_hcomm_info,
ep_world_size=ep_world_size,
ep_rank_id=rank,
moe_expert_num=moeExpertNum,
expert_shard_type=0,
shared_expert_rank_num=sharedExpertRankNum,
quant_mode=quant_mode,
global_bs=globalBS)
if is_quant:
expand_x = expand_x.to(input_dtype)
torch.npu.synchronize()
quant_expand_x, comm_cmd_info = torch_npu.npu_moe_distribute_combine_setup(
expand_x=expand_x,
expert_ids=expert_ids,
assist_info_for_combine=assist_info_for_combine,
group_ep=ep_hcomm_info,
ep_world_size=ep_world_size,
ep_rank_id=rank,
moe_expert_num=moeExpertNum,
expert_shard_type=0,
shared_expert_rank_num=sharedExpertRankNum,
comm_quant_mode=quant_mode,
global_bs=globalBS)
torch.npu.synchronize()
out = torch_npu.npu_moe_distribute_combine_teardown(
expand_x=expand_x,
quant_expand_x=quant_expand_x,
expert_ids=expert_ids,
expand_idx=expand_idx,
expert_scales=expert_scales,
comm_cmd_info=comm_cmd_info,
group_ep=ep_hcomm_info,
ep_world_size=ep_world_size,
ep_rank_id=rank,
moe_expert_num=moeExpertNum,
expert_shard_type=0,
shared_expert_rank_num=sharedExpertRankNum,
comm_quant_mode=quant_mode,
global_bs=globalBS)
c2p.put((rank, out.cpu()))
p2c.get()
class TestMoeDistributeDispatchCombineUrma(TestCase):
@skipIfUnsupportMultiNPU(16)
@SupportedDevices(['Ascend910_'])
def test_npu_moe_distribute_dispatch(self):
ctx = mp.get_context('spawn')
c2p = ctx.Queue(ep_world_size)
p2c = ctx.Queue(ep_world_size)
p_list = []
for rank in range(ep_world_size):
p = ctx.Process(target=test_npu_moe_distribute_dispatch_combine_urma_, args=(rank, c2p, p2c))
p.start()
p_list.append(p)
for _ in range(ep_world_size):
rank_id, out = c2p.get()
print("recv rank", rank_id, "data success")
for _ in range(ep_world_size):
p2c.put(0)
for p in p_list:
p.join()
if __name__ == '__main__':
run_tests()