import os
import unittest
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class TestMoeDistributeDispatch(TestCase):
    """Multi-NPU tests for ``torch_npu.npu_moe_distribute_dispatch_v2``.

    Every scenario launches ``ep_world_size * _PROCS_PER_EP_RANK`` worker
    processes, one per NPU.  Worker ``rank`` acts as EP rank
    ``rank // _PROCS_PER_EP_RANK``; even ranks dispatch the ``x1``/``topk1``
    inputs and odd ranks the ``x2``/``topk2`` inputs.  Each worker's output is
    compared with a host-side golden built from ``npu_moe_init_routing``.
    """

    # The harness always launches two processes per EP rank.  With an operator
    # tp_world_size of 2 they form a TP pair; with tp_world_size == 1 the even
    # and odd ranks act as two independent EP groups.
    _PROCS_PER_EP_RANK = 2

    @classmethod
    def _init_dist_hccl(cls, rank, world_size, ep_world_size, tp_world_size):
        """Initialise the default HCCL process group and this rank's EP/TP subgroups.

        TP group ``i`` is the contiguous block of ranks
        ``[i * tp_world_size, (i + 1) * tp_world_size)``; EP group ``j`` is the
        strided set ``j, j + tp_world_size, j + 2 * tp_world_size, ...``.

        Returns:
            ``(dist, ep_group, tp_group)`` where the groups are the ones that
            contain ``rank``.
        """
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '50000'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)

        ep_ranks_list = [list(range(i, world_size, tp_world_size)) for i in range(tp_world_size)]
        tp_ranks_list = [list(range(i * tp_world_size, (i + 1) * tp_world_size)) for i in range(ep_world_size)]

        # new_group must be entered by every rank, in the same order, even for
        # groups the rank does not belong to; keep only the ones containing us.
        ep_group_tmp = None
        for ranks in ep_ranks_list:
            group = dist.new_group(backend='hccl', ranks=ranks)
            if rank in ranks:
                ep_group_tmp = group
        tp_group_tmp = None
        for ranks in tp_ranks_list:
            group = dist.new_group(backend='hccl', ranks=ranks)
            if rank in ranks:
                tp_group_tmp = group
        return dist, ep_group_tmp, tp_group_tmp

    @classmethod
    def _test_npu_moe_distribute_dispatch_v2(cls, rank, input_list):
        """Worker body: run the dispatch operator on one rank and report its output.

        ``input_list`` is the list packed by ``_test_multiprocess``.  The trimmed
        output is sent to the parent on ``c2p`` as ``(rank, cpu_tensor)``; the
        worker then waits on ``p2c`` for the parent's release token before
        returning.
        """
        (expt_token_list, x1_list, x2_list, topk1_list, topk2_list, elastic_info, ep_world_size, tp_world_size,
         globalBS, sharedExpertRankNum, moeExpertNum, h, init_pg, c2p, p2c) = input_list
        procs_per_ep_rank = cls._PROCS_PER_EP_RANK
        ep_rank = rank // procs_per_ep_rank
        if rank % procs_per_ep_rank == 0:
            x, topk = x1_list[ep_rank], topk1_list[ep_rank]
        else:
            x, topk = x2_list[ep_rank], topk2_list[ep_rank]

        # Groups are always built for the full two-processes-per-EP-rank layout;
        # the operator's own tp_world_size (1 or 2) is passed separately below.
        _, ep_group, tp_group = init_pg(rank, ep_world_size * procs_per_ep_rank, ep_world_size, procs_per_ep_rank)
        npu_device = torch.device('npu')
        ep_hcomm_name = ep_group._get_backend(npu_device).get_hccl_comm_name(rank)
        tp_hcomm_name = tp_group._get_backend(npu_device).get_hccl_comm_name(rank)

        x = x.npu()
        topk = topk.npu()
        if isinstance(elastic_info, (list, tuple)):
            # Callers may pass one elastic_info tensor per EP rank; a plain list
            # has no .npu(), so select this rank's entry first.
            elastic_info = elastic_info[ep_rank]
        if elastic_info is not None:
            elastic_info = elastic_info.npu()

        out, _, _, _, _, _, _ = torch_npu.npu_moe_distribute_dispatch_v2(
            x=x,
            expert_ids=topk,
            elastic_info=elastic_info,
            group_ep=ep_hcomm_name,
            ep_world_size=ep_world_size,
            ep_rank_id=ep_rank,
            moe_expert_num=moeExpertNum,
            scales=None,
            group_tp=tp_hcomm_name,
            tp_world_size=tp_world_size,
            # rank % 1 == 0, so this is 0 whenever tp_world_size == 1.
            tp_rank_id=rank % tp_world_size,
            expert_shard_type=0,
            shared_expert_num=int(sharedExpertRankNum > 0),
            shared_expert_rank_num=sharedExpertRankNum,
            quant_mode=0,
            global_bs=globalBS)
        if tp_world_size == 1:
            torch_npu._npu_distribute_barrier(x_ref=x, group=ep_hcomm_name, world_size=ep_world_size)

        # Rows per TP slice of the output buffer, as modelled by this test:
        # shared-expert ranks split globalBS evenly, MoE ranks reserve
        # globalBS rows per local expert.
        if ep_rank < sharedExpertRankNum:
            rows_per_tp_slice = globalBS // sharedExpertRankNum
        else:
            experts_per_rank = moeExpertNum // (ep_world_size - sharedExpertRankNum)
            rows_per_tp_slice = globalBS * experts_per_rank
        # Only the first expt_token_list[rank] rows carry received tokens.
        out = out.reshape(tp_world_size * rows_per_tp_slice, h)[:int(expt_token_list[rank]), :]
        c2p.put((rank, out.cpu()))
        p2c.get()

    def _test_multiprocess(self, f, init_pg, input_list):
        """Run ``f`` in one spawned process per rank and compare each rank's output.

        ``input_list`` is ``[expt_out_list, expt_token_list, x1_list, x2_list,
        topk1_list, topk2_list, elastic_info, ep_world_size, tp_world_size,
        globalBS, sharedExpertRankNum, moeExpertNum, h]``.  Workers report
        ``(rank, output)`` on one queue and then wait for a release token on a
        second queue.
        """
        (expt_out_list, expt_token_list, x1_list, x2_list, topk1_list, topk2_list, elastic_info, ep_world_size,
         tp_world_size, globalBS, sharedExpertRankNum, moeExpertNum, h) = input_list
        ctx = mp.get_context('spawn')
        num_procs = ep_world_size * self._PROCS_PER_EP_RANK
        c2p = ctx.Queue(num_procs)
        p2c = ctx.Queue(num_procs)
        worker_inputs = [expt_token_list, x1_list, x2_list, topk1_list, topk2_list, elastic_info, ep_world_size,
                         tp_world_size, globalBS, sharedExpertRankNum, moeExpertNum, h, init_pg, c2p, p2c]
        procs = []
        for rank in range(num_procs):
            proc = ctx.Process(target=f, args=(rank, worker_inputs))
            proc.start()
            procs.append(proc)

        outputs = {}
        try:
            for _ in range(num_procs):
                rank, output = c2p.get()
                outputs[rank] = output
        finally:
            # Always release and reap the workers so that a failure in the
            # parent does not leave them blocked on p2c forever.
            for _ in range(num_procs):
                p2c.put(0)
            for proc in procs:
                proc.join()

        for rank, output in sorted(outputs.items()):
            self.assertEqual(output, expt_out_list[rank],
                             ("rank {} Expect receive tensor {} but got {}.").format(rank, expt_out_list[rank], output))

    @staticmethod
    def _gather_shared_expert_rows(x, bs, sharedExpertRankNum, moeExpertNum):
        """Collect the host rows each shared-expert rank is expected to receive.

        Shared-expert rank ``i`` receives the ``bs`` tokens of EP ranks
        ``i, i + S, ..., i + (moeExpertNum // S) * S`` where
        ``S = sharedExpertRankNum``.

        Returns:
            ``(rows, counts)``: one ``(n_i, h)`` tensor and one row count per
            shared-expert rank (both empty when ``S == 0``).
        """
        rows = []
        counts = []
        for i in range(sharedExpertRankNum):
            src_ranks = [i + j * sharedExpertRankNum for j in range(moeExpertNum // sharedExpertRankNum + 1)]
            counts.append(bs * len(src_ranks))
            rows.append(torch.cat([x[bs * r:bs * (r + 1), :] for r in src_ranks], dim=0))
        return rows, counts

    @classmethod
    def _golden_dispatch(cls, x, topk, row_idx, bs, h, globalBS, sharedExpertRankNum, moeExpertNum):
        """Build the golden rows for one input stream plus the row count per EP rank.

        Shared-expert rows come first, followed by the rows produced by
        ``npu_moe_init_routing``.  Per-expert counts are appended after the
        shared-expert counts.  In this test each MoE rank hosts one expert, so
        entry ``i`` of the counts is the row count for EP rank ``i``.
        """
        expand_x, _, expand_expert = torch_npu.npu_moe_init_routing(x.npu(), row_idx=row_idx.npu(),
                                                                    expert_idx=topk.npu(),
                                                                    active_num=globalBS)
        expand_x = expand_x.cpu()
        expand_expert = expand_expert.cpu()
        shared_rows, shared_counts = cls._gather_shared_expert_rows(x, bs, sharedExpertRankNum, moeExpertNum)
        if shared_rows:
            # Keep the routed rows' dtype: a hard-coded fp16 cast would make
            # torch.cat promote bf16 goldens to float32.
            shared_x = torch.cat(shared_rows, dim=0).to(expand_x.dtype)
            golden = torch.cat((shared_x, expand_x)).view(-1, h)
        else:
            golden = expand_x.view(-1, h)
        counts = torch.cat((torch.tensor(shared_counts, dtype=torch.int64),
                            torch.bincount(expand_expert, minlength=moeExpertNum)))
        return golden, counts

    def _construct_excepted_result(self, x1_list, x2_list, topk1_list, topk2_list, bs, h, k, globalBS,
                                   sharedExpertRankNum, moeExpertNum, ep_world_size, tp_world_size):
        """Build the expected per-worker outputs and row counts for one scenario.

        Returns:
            ``(out_list, token_list)`` indexed by worker rank, with
            ``2 * ep_world_size`` entries each.  With ``tp_world_size == 2``,
            both workers of EP rank ``i`` expect the rows of both input streams,
            with the worker's own stream first.  Otherwise each worker expects
            only its own stream's rows.
        """
        col_idx = torch.arange(0, globalBS * k, dtype=torch.int32)
        row_idx = col_idx.view(k, -1).permute(1, 0).reshape([globalBS, k]).contiguous()
        x1 = torch.cat(x1_list, dim=0).view(-1, h)
        x2 = torch.cat(x2_list, dim=0).view(-1, h)
        topk1 = torch.cat(topk1_list, dim=0).view(-1, k)
        topk2 = torch.cat(topk2_list, dim=0).view(-1, k)
        golden1, token1 = self._golden_dispatch(x1, topk1, row_idx, bs, h, globalBS, sharedExpertRankNum,
                                                moeExpertNum)
        golden2, token2 = self._golden_dispatch(x2, topk2, row_idx, bs, h, globalBS, sharedExpertRankNum,
                                                moeExpertNum)

        out_list = []
        token_list = []
        start1 = 0
        start2 = 0
        for i in range(ep_world_size):
            n1 = int(token1[i])
            n2 = int(token2[i])
            rows1 = golden1[start1:start1 + n1, :]
            rows2 = golden2[start2:start2 + n2, :]
            start1 += n1
            start2 += n2
            if tp_world_size == 2:
                out_list.append(torch.cat((rows1, rows2)))
                out_list.append(torch.cat((rows2, rows1)))
                token_list.extend([n1 + n2, n1 + n2])
            else:
                out_list.extend([rows1, rows2])
                token_list.extend([n1, n2])
        return out_list, token_list

    @staticmethod
    def _make_inputs(ep_world_size, bs, h, k, dtype):
        """Create per-EP-rank ``x1``/``x2`` inputs of shape ``[bs, h]`` and their topk ids.

        The inputs come from ``create_common_tensor(..., -1, 1)``.  Every token
        is routed to experts ``0..k-1``.
        """
        topk = torch.tile(torch.arange(k), (bs,)).int().view(-1, k)
        x_shape = [dtype, -1, [bs, h]]
        x1_list = []
        x2_list = []
        for _ in range(ep_world_size):
            x1_list.append(create_common_tensor(x_shape, -1, 1)[0])
            x2_list.append(create_common_tensor(x_shape, -1, 1)[0])
        return x1_list, x2_list, [topk] * ep_world_size, [topk] * ep_world_size

    def _run_scenarios(self, x1_list, x2_list, topk1_list, topk2_list, elastic_info, ep_world_size, bs, h, k,
                       scenarios):
        """Build every scenario's golden, then run each scenario across worker processes.

        ``scenarios`` holds ``(shared_expert_rank_num, moe_expert_num, tp_world_size)`` tuples.
        """
        global_bs = bs * ep_world_size
        expected = [
            self._construct_excepted_result(x1_list, x2_list, topk1_list, topk2_list, bs, h, k, global_bs,
                                            shared, moe, tp)
            for shared, moe, tp in scenarios
        ]
        for (shared, moe, tp), (expt_out_list, expt_token_list) in zip(scenarios, expected):
            self._test_multiprocess(
                TestMoeDistributeDispatch._test_npu_moe_distribute_dispatch_v2,
                TestMoeDistributeDispatch._init_dist_hccl,
                [expt_out_list, expt_token_list, x1_list, x2_list, topk1_list, topk2_list, elastic_info,
                 ep_world_size, tp, global_bs, shared, moe, h])

    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend910_93', 'Ascend950'])
    def test_npu_moe_distribute_dispatch_v2(self):
        """fp16 inputs with per-EP-rank elastic_info, TP=2 and TP=1, with and without a shared-expert rank."""
        ep_world_size = 8
        bs = 8
        h = 7168
        k = 4
        x1_list, x2_list, topk1_list, topk2_list = self._make_inputs(ep_world_size, bs, h, k, np.float16)
        elastic_info_shape = [np.int32, -1, [4 + ep_world_size]]
        elastic_info = [create_common_tensor(elastic_info_shape, -1, 1)[0] for _ in range(ep_world_size)]
        # (shared-expert rank count, MoE expert count, operator tp_world_size)
        scenarios = [(1, 7, 2), (0, 8, 2), (1, 7, 1), (0, 8, 1)]
        self._run_scenarios(x1_list, x2_list, topk1_list, topk2_list, elastic_info, ep_world_size, bs, h, k,
                            scenarios)

    # The harness spawns ep_world_size * 2 = 16 processes, each calling
    # set_device(rank), so 16 NPUs are required even though TP is 1 here.
    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend950'])
    def test_npu_moe_distribute_dispatch_arch35(self):
        """bf16 inputs without elastic_info, TP=1, with and without a shared-expert rank."""
        ep_world_size = 8
        bs = 64
        h = 4096
        k = 16
        # NOTE(review): stock numpy has no ``bfloat16`` attribute; this relies on
        # it being provided by the environment — confirm.
        # NOTE(review): with k = 16 the topk ids span 0..15, while
        # moe_expert_num is 7 or 8 — confirm the operator accepts this.
        x1_list, x2_list, topk1_list, topk2_list = self._make_inputs(ep_world_size, bs, h, k, np.bfloat16)
        # (shared-expert rank count, MoE expert count, operator tp_world_size)
        scenarios = [(1, 7, 1), (0, 8, 1)]
        self._run_scenarios(x1_list, x2_list, topk1_list, topk2_list, None, ep_world_size, bs, h, k, scenarios)
# Script entry point: when executed directly, hand control to torch_npu's
# test runner instead of relying on an external test collector.
if __name__ == '__main__':
    run_tests()