import os
import unittest
import numpy as np
import random
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class TestMoeDistributeDispatch(TestCase):
    """Multi-process tests comparing the NPU MoE dispatch ops with a CPU golden model.

    Each test starts one worker process per expert-parallel (EP) rank, runs
    ``torch_npu.npu_moe_distribute_dispatch`` (or the ``_v2`` variant) in it and
    compares every rank's outputs with the reference produced by
    ``gen_dispatch_golden``.
    """

    def chunk_tensor(self, tensor, num_chunks):
        """Split ``tensor`` along dim 0 into ``num_chunks`` equally sized slices.

        The slice length is ``tensor.size(0) // num_chunks``; trailing rows that do
        not fill a whole slice are dropped.

        Returns:
            list of ``num_chunks`` tensor views.
        """
        chunk_size = tensor.size(0) // num_chunks
        return [tensor[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]

    def gen_x(self, shape, dtype):
        """Return a tensor of ``shape`` drawn uniformly from [-5, 5) and cast to ``dtype``."""
        values = np.random.uniform(-5, 5, size=shape).astype(np.float32)
        return torch.tensor(values).to(dtype).view(shape)

    def gen_expert_ids(self, shape, total_expert_num):
        """Return an int32 ``(shape[0], shape[1])`` tensor of expert ids.

        Every row holds ``shape[1]`` distinct ids sampled without replacement from
        ``range(total_expert_num)``.

        Raises:
            ValueError: from ``random.sample`` if ``shape[1] > total_expert_num``.
        """
        candidates = list(range(total_expert_num))
        rows = [random.sample(candidates, shape[1]) for _ in range(shape[0])]
        return torch.tensor(rows, dtype=torch.int32).view(shape[0], shape[1])

    def gen_scale(self, shape, has_scale):
        """Return a float32 scale tensor of ``shape`` or ``None`` when ``has_scale`` is false.

        The lower and upper bound of the uniform draw are both 1.0, so every
        generated scale equals 1.0.
        """
        if not has_scale:
            return None
        return torch.tensor(np.random.uniform(1.0, 1.0, size=shape).astype(np.float32)).to(torch.float32)

    def gen_dispatch_golden(self, x, expert_ids, scales, has_scale, k, quant_mode, global_bs, ep_world_size, bs,
                            total_expert_num, expert_num_per_rank):
        """Build the CPU reference for a dispatch over all ranks.

        Args:
            x: ``(global_bs, h)`` tokens of all ranks, rank after rank.
            expert_ids: ``(global_bs, k)`` expert id chosen for every (token, slot).
            scales: ``(total_expert_num, h)`` per-expert scales, or ``None``.
            has_scale: whether ``scales`` is applied.
            k: number of experts selected per token.
            quant_mode: ``2`` quantizes rows to int8 with a per-row scale; any other
                value keeps the data as bfloat16.
            global_bs: total number of tokens over all ranks.
            ep_world_size: number of ranks.
            bs: number of tokens per rank.
            total_expert_num: number of experts over all ranks.
            expert_num_per_rank: number of experts hosted by each rank.

        Returns:
            ``(golden_list, actual_tokens)`` where ``golden_list`` is
            ``[expand_x_sorted, dynamic_scales_sorted, expand_idx,
            expert_tokens_num_cumsum, ep_recv_counts_cumsum, None]`` and
            ``actual_tokens`` is the int32 running total of tokens received by
            ranks ``0..r``.
        """
        flat_expert_ids = expert_ids.view(global_bs * k)

        # One copy of every token per selected expert.
        expand_x = torch.repeat_interleave(x, k, dim=0)
        if has_scale:
            expand_x = expand_x.to(torch.float32)
            # Pick, for every expanded row, the scale row of the expert it is routed to.
            # (torch.gather cannot be used here: the index is 1-D while scales is 2-D.)
            # NOTE(review): assumes scales are per-expert rows of width h, as generated by
            # the tests in this class - confirm against the op specification.
            scales_gather = scales.index_select(0, flat_expert_ids.to(torch.int64))
            expand_x = torch.mul(expand_x, scales_gather)

        dynamic_scales = None
        if quant_mode == 2:
            # Per-row quantization: scale = 127 / max(|row|), values are cast to int8.
            expand_x = expand_x.to(torch.float32)
            max_value, _ = torch.max(torch.abs(expand_x), dim=1)
            dynamic_scales = (torch.tensor([127.0]).to(torch.float32) / max_value).view(-1, 1).to(torch.float32)
            expand_x = (expand_x * dynamic_scales).to(torch.int8)
        else:
            expand_x = expand_x.to(torch.bfloat16)

        # Stable sort by expert id groups rows per expert while keeping token order.
        _, sorted_idx = torch.sort(flat_expert_ids, stable=True)
        expand_x_sorted = expand_x[sorted_idx]
        dynamic_scales_sorted = dynamic_scales[sorted_idx].view(-1) if quant_mode == 2 else None

        # expand_idx: for every (token, slot) of a rank, its running index among the
        # entries of the same rank that selected the same expert.
        expert_ids_input = self.chunk_tensor(flat_expert_ids, ep_world_size)
        expand_idx = torch.zeros(size=(global_bs, k)).to(torch.int32)
        for rank_id, rank_expert_ids in enumerate(expert_ids_input):
            rank_expert_ids = rank_expert_ids.view(-1)
            expand_idx_per_rank = torch.zeros(size=(bs, k)).to(torch.int32).view(-1)
            for expert_id in torch.unique(rank_expert_ids):
                positions = (rank_expert_ids == expert_id).nonzero(as_tuple=True)[0]
                expand_idx_per_rank[positions] = torch.arange(positions.numel(), dtype=torch.int32)
            expand_idx[rank_id * bs: (rank_id + 1) * bs, :] = expand_idx_per_rank.view(bs, k)

        # Tokens received by every expert; minlength covers experts nobody selected.
        expert_token_nums = torch.bincount(flat_expert_ids, minlength=total_expert_num).to(torch.int32)
        expert_tokens_num_cumsum = [
            torch.cumsum(expert_token_nums[rank_id * expert_num_per_rank: (rank_id + 1) * expert_num_per_rank], dim=0)
            for rank_id in range(ep_world_size)
        ]

        # Expert-major table: how many entries every source rank sends to every expert.
        ep_recv_counts = torch.tensor([
            torch.sum(rank_expert_ids.eq(expert_id)).item()
            for expert_id in range(total_expert_num)
            for rank_expert_ids in expert_ids_input
        ]).to(torch.int32)
        recv_span = expert_num_per_rank * ep_world_size
        ep_recv_counts_cumsum = [
            torch.cumsum(ep_recv_counts[rank_id * recv_span: (rank_id + 1) * recv_span], dim=0)
            for rank_id in range(ep_world_size)
        ]

        # Running total of tokens received by ranks 0..rank_id.
        actual_tokens = []
        received = 0
        for rank_id in range(ep_world_size):
            rank_slice = expert_token_nums[rank_id * expert_num_per_rank: (rank_id + 1) * expert_num_per_rank]
            received = received + torch.sum(rank_slice).item()
            actual_tokens.append(received)
        actual_tokens = torch.tensor(actual_tokens).to(torch.int32)

        golden_list = [expand_x_sorted, dynamic_scales_sorted, expand_idx, expert_tokens_num_cumsum,
                       ep_recv_counts_cumsum, None]
        return golden_list, actual_tokens

    def golden_compare(self, rank_id, golden_tensor_list, golden_actual_tokens_cumsum, npu_result, quant_mode, bs, k):
        """Assert that the outputs of rank ``rank_id`` match its slice of the golden data.

        Args:
            rank_id: rank whose result is checked.
            golden_tensor_list: first return value of ``gen_dispatch_golden``.
            golden_actual_tokens_cumsum: second return value of ``gen_dispatch_golden``.
            npu_result: list of CPU tensors produced by ``run_dispatch_npu`` for this rank.
            quant_mode: quantization mode the op was run with.
            bs: number of tokens per rank.
            k: number of experts selected per token.
        """
        # Rows of the globally sorted golden tensor that belong to this rank.
        start_offset = golden_actual_tokens_cumsum[rank_id - 1].item() if rank_id > 0 else 0
        end_offset = golden_actual_tokens_cumsum[rank_id].item()
        token_count = end_offset - start_offset

        expand_x_golden = golden_tensor_list[0][start_offset: end_offset, :]
        # Only the first token_count rows of the NPU output are compared.
        expand_x_npu = npu_result[0][0: token_count, :]
        if quant_mode == 0:
            self.assertEqual(expand_x_golden, expand_x_npu,
                             ("rank {} Expect receive tensor {} but got {}.").format(
                                 rank_id, expand_x_golden, expand_x_npu))
        else:
            # NOTE(review): confirm that TestCase.assertRtolEqual accepts an ``atol`` keyword.
            self.assertRtolEqual(expand_x_golden, expand_x_npu, atol=1)

        if quant_mode == 2:
            dynamic_scales_golden = golden_tensor_list[1][start_offset: end_offset]
            dynamic_scales_npu = npu_result[1][0: token_count]
            self.assertRtolEqual(dynamic_scales_golden, dynamic_scales_npu, 0.001)

        expand_idx_golden = golden_tensor_list[2][bs * rank_id: bs * (rank_id + 1), :]
        expand_idx_npu = npu_result[2][:bs * k].view(bs, k)
        self.assertEqual(expand_idx_golden, expand_idx_npu,
                         ("rank {} Expect receive tensor {} but got {}.").format(
                             rank_id, expand_idx_golden, expand_idx_npu))

        expert_tokens_num_golden = golden_tensor_list[3][rank_id]
        expert_tokens_num_npu = npu_result[3]
        self.assertEqual(expert_tokens_num_golden, expert_tokens_num_npu,
                         ("rank {} Expect receive tensor {} but got {}.").format(
                             rank_id, expert_tokens_num_golden, expert_tokens_num_npu))

        ep_recv_counts_golden = golden_tensor_list[4][rank_id]
        ep_recv_counts_npu = npu_result[4]
        self.assertEqual(ep_recv_counts_golden, ep_recv_counts_npu,
                         ("rank {} Expect receive tensor {} but got {}.").format(
                             rank_id, ep_recv_counts_golden, ep_recv_counts_npu))

    def golden_compare_performance_info(self, performance_info):
        """Check that ``performance_info`` holds at least one non-zero value.

        ``None`` is accepted and skipped.

        Raises:
            ValueError: if every element of ``performance_info`` is zero.
        """
        if performance_info is None:
            return
        # Reduce the element-wise comparison; Tensor.all() takes a dim, not a mask.
        if (performance_info == 0).all():
            raise ValueError("The performance_info Tensor is all zeros, at least one non-zero value is required!")

    @classmethod
    def run_dispatch_npu(cls, queue, rank, x, expert_ids, scales, ep_world_size, has_scale, total_expert_num,
                         quant_mode, global_bs, use_comm_alg=False, comm_alg=None, performance_info=None):
        """Worker entry point: run the dispatch op for one rank and publish its outputs.

        Puts ``(rank, outputs)`` on ``queue`` where ``outputs`` is
        ``[expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts,
        None, performance_info]``, all tensors copied to the CPU. The last entry is
        ``None`` when no ``performance_info`` was supplied.

        Args:
            queue: queue shared with the parent process.
            rank: EP rank of this worker.
            x: this rank's tokens (CPU tensor).
            expert_ids: this rank's expert ids (CPU tensor).
            scales: scale tensor, used only when ``has_scale`` is true.
            ep_world_size: number of ranks in the EP group.
            has_scale: whether ``scales`` is passed to the op.
            total_expert_num: value passed as ``moe_expert_num``.
            quant_mode: quantization mode passed to the op.
            global_bs: value passed as ``global_bs``.
            use_comm_alg: call ``npu_moe_distribute_dispatch_v2`` instead of the v1 op.
            comm_alg: communication algorithm name for the v2 op.
            performance_info: optional CPU tensor handed to the v2 op.
        """
        # NOTE(review): with 16 ranks, rank % 8 places two ranks on each of devices 0-7
        # although the tests require 16 NPUs - confirm this mapping is intended.
        torch_npu.npu.set_device(rank % 8)
        dist.init_process_group(backend="hccl", rank=rank, world_size=ep_world_size,
                                init_method='tcp://' + "127.0.0.1" + ':' + "50000")
        ep_group = dist.new_group(backend="hccl", ranks=list(range(ep_world_size)))
        ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)

        # Arguments shared by both variants of the op.
        op_kwargs = dict(
            x=x.npu(),
            expert_ids=expert_ids.npu(),
            group_ep=ep_hcomm_info,
            ep_world_size=ep_world_size,
            ep_rank_id=rank,
            moe_expert_num=total_expert_num,
            scales=scales.npu() if has_scale else None,
            quant_mode=quant_mode,
            global_bs=global_bs,
        )

        performance_info_npu = None
        if use_comm_alg:
            # Move the tensor to the device like every other input.
            # NOTE(review): assumes the op fills performance_info in place on the device -
            # confirm against the op documentation.
            if performance_info is not None:
                performance_info_npu = performance_info.npu()
            outputs = torch_npu.npu_moe_distribute_dispatch_v2(
                comm_alg=comm_alg,
                performance_info=performance_info_npu,
                **op_kwargs,
            )
        else:
            outputs = torch_npu.npu_moe_distribute_dispatch(**op_kwargs)

        expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, _tp_recv_counts, _ = outputs
        # The worker has its own copy of performance_info, so ship it back explicitly.
        performance_info_cpu = performance_info_npu.cpu() if performance_info_npu is not None else None
        queue.put((rank, [expand_x.cpu(), dynamic_scales.cpu(), expand_idx.cpu(), expert_token_nums.cpu(),
                          ep_recv_counts.cpu(), None, performance_info_cpu]))

    def _run_dispatch_case(self, use_comm_alg=False, comm_alg=None):
        """Run one 16-rank dispatch scenario and compare every rank with the golden data.

        With ``use_comm_alg`` the v2 op is exercised with ``comm_alg`` and a zeroed
        ``performance_info`` tensor per rank, which is checked afterwards.
        """
        has_scale = False
        quant_mode = 0
        ep_world_size = 16
        bs = 8
        h = 7168
        k = 8
        expert_num_per_rank = 1
        global_bs = bs * ep_world_size
        total_expert_num = ep_world_size * expert_num_per_rank

        x = self.gen_x((global_bs, h), torch.bfloat16)
        expert_ids = self.gen_expert_ids((global_bs, k), total_expert_num)
        scales = self.gen_scale((total_expert_num, h), has_scale)
        x_input = self.chunk_tensor(x, ep_world_size)
        expert_ids_input = self.chunk_tensor(expert_ids, ep_world_size)
        golden_tensor_list, golden_actual_tokens = self.gen_dispatch_golden(
            x, expert_ids, scales, has_scale, k, quant_mode, global_bs, ep_world_size, bs, total_expert_num,
            expert_num_per_rank)

        # The context manager shuts the manager process down even if an assertion fails.
        with mp.Manager() as manager:
            result_queue = manager.Queue()
            mp.set_start_method("forkserver", force=True)
            workers = []
            for rank_id in range(ep_world_size):
                performance_info = torch.zeros(ep_world_size, dtype=torch.int64) if use_comm_alg else None
                worker = mp.Process(
                    target=TestMoeDistributeDispatch.run_dispatch_npu,
                    args=(result_queue, rank_id, x_input[rank_id], expert_ids_input[rank_id], scales,
                          ep_world_size, has_scale, total_expert_num, quant_mode, global_bs,
                          use_comm_alg, comm_alg, performance_info))
                worker.start()
                workers.append(worker)
            for worker in workers:
                worker.join()

            # All producers have exited, so the queue content is final. Draining it
            # instead of calling get() once per worker avoids blocking forever when a
            # worker died before publishing its result.
            results = {}
            while not result_queue.empty():
                rank_id, rank_result = result_queue.get()
                results[rank_id] = rank_result
            missing_ranks = [rank_id for rank_id in range(ep_world_size) if rank_id not in results]
            self.assertFalse(missing_ranks, "no result received from ranks {}".format(missing_ranks))

            for rank_id in range(ep_world_size):
                self.golden_compare(rank_id, golden_tensor_list, golden_actual_tokens, results[rank_id],
                                    quant_mode, bs, k)
                if use_comm_alg:
                    self.golden_compare_performance_info(results[rank_id][6])

    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend910B'])
    def test_npu_moe_distribute_dispatch(self):
        """Check ``npu_moe_distribute_dispatch`` on 16 ranks without scales or quantization."""
        self._run_dispatch_case()

    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend910B'])
    def test_npu_moe_distribute_dispatch_v2(self):
        """Check ``npu_moe_distribute_dispatch_v2`` with the "fullmesh" algorithm on 16 ranks."""
        self._run_dispatch_case(use_comm_alg=True, comm_alg="fullmesh")
# Script entry point: hand control to the torch_npu test runner when executed directly.
if __name__ == '__main__':
    run_tests()