"""
The scripts provides an example for multiple comm groups with allreduce cascading case
Main Functions:
- allreduce_cascading_worker: Main function for allreduce_cascading
- test_allreduce_cascading: Performance test function
"""
import multiprocessing as mp
import traceback
import numpy as np
import pytest
import torch
import pypto
from distributed_config import DistributedConfig, collect_process_errors
def _create_shmem_tensors(group_names, world_sizes, shmem_shape):
    """Create an FP32 shmem tensor and a symbolic PE for each of two groups.

    Only the first two entries of ``group_names`` and ``world_sizes`` are read.

    Args:
        group_names: Sequence of communication group names.
        world_sizes: Sequence of group sizes, parallel to ``group_names``.
        shmem_shape: Shape handed to ``create_shmem_tensor`` for both tensors.

    Returns:
        Tuple ``(shmem_tensor, my_pe, shmem_tensor1, my_pe1)``: the first pair
        is for ``group_names[0]``, the second for ``group_names[1]``.
    """
    per_group = []
    for group_idx in range(2):
        # Tensor first, then the PE, so the call order matches group by group.
        per_group.append(pypto.distributed.create_shmem_tensor(
            group_names[group_idx], world_sizes[group_idx], pypto.DT_FP32, shmem_shape))
        per_group.append(pypto.distributed.my_symbolic_pe(group_names[group_idx]))
    return tuple(per_group)
def _get_input_tile(in_tensor, batch_size, bs_idx, view_row_shape, hidden_size):
    """Return the ``bs_idx``-th row tile of ``in_tensor`` cast to FP32.

    Args:
        in_tensor: 2-D input tensor (``shape[1]`` is read as the column count).
        batch_size: Total number of rows in ``in_tensor``.
        bs_idx: Index of the row tile to extract.
        view_row_shape: Number of rows per tile.
        hidden_size: Second dimension used for the vector tile shape.

    Returns:
        The viewed tile cast to ``pypto.DT_FP32``.
    """
    num_cols = in_tensor.shape[1]
    tile_shape = (view_row_shape, num_cols)
    tile_offset = [bs_idx * view_row_shape, 0]
    # The last tile may be partial: its valid row count is derived from the
    # rows left in the batch via ``.min(view_row_shape)``.
    rows_left = batch_size - bs_idx * view_row_shape
    tile_valid_shape = [rows_left.min(view_row_shape), num_cols]
    tile = pypto.view(in_tensor, tile_shape, tile_offset, valid_shape=tile_valid_shape)
    pypto.set_vec_tile_shapes(view_row_shape, hidden_size)
    return pypto.cast(tile, pypto.DT_FP32)
def _perform_allreduce(input_tensor, stage_params):
    """Run one AllReduce stage through a shmem tensor.

    Args:
        input_tensor: Input tensor for this stage.
        stage_params: Sequence of eight values, in this order:
            [shmem_tensor, world_size, shmem_shape, my_pe,
             view_row_shape, hidden_size, batch_size, bs_idx]

    Returns:
        The AllReduce result cast to BF16.
    """
    (shmem_buf, group_size, buf_shape, local_pe,
     tile_rows, hidden, batch, tile_idx) = stage_params

    pypto.set_vec_tile_shapes(tile_rows, hidden)

    # For every PE index in the group: put the input with an atomic ADD, then
    # raise that PE's signal by 1 once the put has been issued.
    for peer_pe in range(group_size):
        put_done = pypto.distributed.shmem_put(
            input_tensor, [0, 0], shmem_buf, peer_pe,
            put_op=pypto.AtomicType.ADD, pred=[input_tensor],
        )
        pypto.distributed.shmem_signal(
            shmem_buf, peer_pe, 1, buf_shape, [0, 0],
            target_pe=peer_pe, sig_op=pypto.AtomicType.ADD, pred=[put_done],
        )

    # Wait on the local PE with cmp=EQ against the group size; the signal is
    # cleared afterwards (clear_signal=True).
    signals_ready = pypto.distributed.shmem_wait_until(
        shmem_buf, local_pe, group_size, buf_shape, [0, 0],
        cmp=pypto.OpType.EQ, clear_signal=True, pred=[input_tensor],
    )

    valid_rows = (batch - tile_idx * tile_rows).min(tile_rows)
    reduced_fp32 = pypto.distributed.shmem_get(
        shmem_buf, local_pe, buf_shape, [0, 0],
        pred=[signals_ready], valid_shape=[valid_rows, hidden],
    )

    pypto.set_vec_tile_shapes(1, hidden)
    return pypto.cast(reduced_fp32, pypto.DT_BF16)
@pypto.frontend.jit()
def allreduce_cascading_kernel(
    in_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
    out_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
    group_names,
    world_sizes,
):
    # Cascading AllReduce kernel: each 8-row tile of in_tensor is reduced over
    # group_names[0] (stage 1), the BF16 result is reduced again over
    # group_names[1] (stage 2), and the final tile is written into out_tensor.
    #
    # in_tensor / out_tensor: BF16 tensors; shape[0] is the row count and
    #     shape[1] the row width.
    # group_names / world_sizes: parallel sequences; entries 0 and 1 are used.
    tile_rows = 8
    total_rows = in_tensor.shape[0]
    row_width = in_tensor.shape[1]
    # Ceiling division so a trailing partial tile still gets an iteration.
    num_tiles = (total_rows + tile_rows - 1) // tile_rows
    tile_shape = [tile_rows, row_width]
    for bs_idx in pypto.loop(num_tiles, name="LOOP_ALLREDUCE_CASCADING", idx_name="bs_idx"):
        shmem_stage1, pe_stage1, shmem_stage2, pe_stage2 = _create_shmem_tensors(
            group_names, world_sizes, tile_shape)
        tile_fp32 = _get_input_tile(
            in_tensor, total_rows, bs_idx, tile_rows, row_width)

        # Stage 1: reduce the FP32 input tile over the first group.
        stage1_out = _perform_allreduce(
            tile_fp32,
            [shmem_stage1, world_sizes[0], tile_shape, pe_stage1,
             tile_rows, row_width, total_rows, bs_idx])

        # Stage 2: feed the stage-1 BF16 result into the second group.
        stage2_out = _perform_allreduce(
            stage1_out,
            [shmem_stage2, world_sizes[1], tile_shape, pe_stage2,
             tile_rows, row_width, total_rows, bs_idx])

        row_offset = bs_idx * pypto.symbolic_scalar(tile_rows)
        out_tensor[row_offset:] = stage2_out
def generate_group_splits(world_size, group_info=None):
    """Return the rank layout of the four communication groups.

    Args:
        world_size: Total number of ranks.
        group_info: Optional pre-built mapping; when it is not ``None`` it is
            returned unchanged (same object).

    Returns:
        A dict mapping ``"even_odd_0"`` / ``"even_odd_1"`` to the even / odd
        ranks and ``"half_0"`` / ``"half_1"`` to the lower / upper half of the
        ranks, or ``group_info`` itself when one was supplied.
    """
    if group_info is None:
        midpoint = world_size // 2
        all_ranks = list(range(world_size))
        group_info = {
            "even_odd_0": all_ranks[0::2],
            "even_odd_1": all_ranks[1::2],
            "half_0": all_ranks[:midpoint],
            "half_1": all_ranks[midpoint:],
        }
    return group_info
def generate_allreduce_cascading_golden_data(config: DistributedConfig, group_info=None):
    """Build random per-rank inputs and the expected cascading AllReduce outputs.

    The reference result is computed in two stages. Stage 1 sums the inputs of
    each ``even_odd_*`` group in FP32 and rounds the sum to BF16. Stage 2 sums
    those stage-1 results over each ``half_*`` group in FP32 and rounds to BF16
    again.

    Args:
        config: Object whose ``world_size`` attribute gives the rank count.
        group_info: Optional group layout, forwarded to
            ``generate_group_splits``.

    Returns:
        ``(input_datas, output_datas)``. ``input_datas[rank]`` is a one-element
        list holding that rank's shared-memory BF16 tensor of shape (13, 256);
        ``output_datas[rank]`` is the expected BF16 output. Ranks in the same
        stage-2 group share the same tensor object, and a rank that belongs to
        no stage-2 group keeps ``None``.
    """
    batch_size, hidden_size = 13, 256
    world_size = config.world_size

    input_datas = [
        [torch.randn((batch_size, hidden_size), dtype=torch.bfloat16).share_memory_()]
        for _ in range(world_size)
    ]
    group_info = generate_group_splits(world_size, group_info)

    def _reduce_within_groups(per_rank_tensors, group_keys):
        # Sum each listed group's tensors in FP32, round to BF16, and hand the
        # same result tensor to every member rank.
        reduced_per_rank = [None] * world_size
        for key in group_keys:
            members = group_info.get(key, [])
            if not members:
                continue
            acc_fp32 = torch.zeros((batch_size, hidden_size), dtype=torch.float32)
            for rank in members:
                acc_fp32 += per_rank_tensors[rank].to(torch.float32).cpu()
            reduced_bf16 = acc_fp32.to(torch.bfloat16)
            for rank in members:
                reduced_per_rank[rank] = reduced_bf16
        return reduced_per_rank

    first_stage_inputs = [entry[0] for entry in input_datas]
    intermediate_datas = _reduce_within_groups(
        first_stage_inputs, ["even_odd_0", "even_odd_1"])
    output_datas = _reduce_within_groups(intermediate_datas, ["half_0", "half_1"])
    return input_datas, output_datas
def allreduce_cascading_worker(worker_params, error_queue: mp.Queue):
    """Run the cascading AllReduce kernel on one rank and check its output.

    Args:
        worker_params: Sequence of five values, in this order:
            [config, input_data, output_data, logical_rank_id, group_info]
        error_queue: Queue that receives ``(logical_rank_id, message,
            traceback)`` when the worker fails; may be ``None``.

    Raises:
        Exception: Any failure is re-raised after it has been queued. This
            includes the ``AssertionError`` raised when the kernel output
            differs from the golden data beyond ``rtol=atol=8e-3``.
    """
    # Bound before the try so the except handler can always report a rank,
    # even when unpacking worker_params is the step that failed.
    logical_rank_id = None
    try:
        config, input_data, output_data, logical_rank_id, group_info = worker_params
        group_info = generate_group_splits(config.world_size, group_info)
        groups = config.init_hccl_comm(logical_rank_id, group_info)
        physical_device_id = config.get_physical_device_id(logical_rank_id)
        device = f'npu:{physical_device_id}'

        in_tensor = input_data[0]
        golden_out_tensor = output_data
        out_tensor = torch.empty(in_tensor.shape, dtype=torch.bfloat16, device=device)
        inputs = [in_tensor.to(device)]

        # Stage-1 group is chosen by rank parity, stage-2 group by which half
        # of the world the rank falls in.
        group_key0 = "even_odd_0" if logical_rank_id % 2 == 0 else "even_odd_1"
        group_name0 = groups[0]
        world_size0 = len(group_info.get(group_key0, []))

        mid = config.world_size // 2
        group_key1 = "half_0" if logical_rank_id < mid else "half_1"
        group_name1 = groups[1]
        world_size1 = len(group_info.get(group_key1, []))

        group_names = [group_name0, group_name1]
        world_sizes = [world_size0, world_size1]
        allreduce_cascading_kernel(*inputs, out_tensor, group_names, world_sizes)

        np.testing.assert_allclose(
            np.array(out_tensor.cpu().flatten().tolist()),
            np.array(golden_out_tensor.cpu().flatten().tolist()),
            rtol=8e-3,
            atol=8e-3,
        )
    except Exception as e:
        if error_queue is not None:
            error_queue.put((logical_rank_id, str(e), traceback.format_exc()))
        raise
@pytest.mark.world_size(4)
def test_allreduce_cascading():
    """Launch one worker process per rank and verify the cascading AllReduce.

    Golden data is generated once in the parent process. Each spawned worker
    receives its own input and expected output, runs the kernel, and reports
    failures through a shared error queue that is inspected after all workers
    have been joined.
    """
    mp.set_start_method('spawn', force=True)
    config = DistributedConfig(world_size=4)
    # Explicit group layout for 4 ranks. It matches what generate_group_splits
    # produces for world_size=4, and is passed to both the golden-data
    # generator and the workers so they use the same layout.
    default_group_info = {
        "even_odd_0": [0, 2],
        "even_odd_1": [1, 3],
        "half_0": [0, 1],
        "half_1": [2, 3],
    }
    input_datas, output_datas = generate_allreduce_cascading_golden_data(
        config, default_group_info)

    error_queue = mp.Queue()
    processes = []
    for i in range(config.world_size):
        worker_params = [config, input_datas[i], output_datas[i], i, default_group_info]
        p = mp.Process(
            target=allreduce_cascading_worker,
            args=(worker_params, error_queue)
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    collect_process_errors(processes, error_queue)
def main():
    """Run the cascading AllReduce test directly, without the pytest runner."""
    test_allreduce_cascading()


# Script entry point.
if __name__ == '__main__':
    main()