"""
AllGather MatMul Module
This module implements a fused all-gather and matmul operation for distributed models.
It efficiently combines communication and computation, reducing memory overhead.
Main Functions:
- allgather_matmul: Main function for fused all-gather and matmul computation
- test_allgather_matmul: Functional test function
"""
import multiprocessing as mp
import traceback
import numpy as np
import pytest
import torch
from torch._dynamo import allow_in_graph
from torch._subclasses import fake_tensor
import pypto
from distributed_config import DistributedConfig, collect_process_errors
@pypto.frontend.jit()
def allgather_matmul_kernel(
in_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
matmul_weight: pypto.Tensor(),
out_tensor: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
group_name,
world_size,
attn_dim_per_tp_val,
):
batch_size_per_rank = in_tensor.shape[0]
view_row_shape = 8
bs_loop = (batch_size_per_rank + view_row_shape - 1) // view_row_shape
pypto.set_vec_tile_shapes(attn_dim_per_tp_val)
for bs_idx in pypto.loop(bs_loop, name="LOOP_AG_MM", idx_name="bs_idx"):
shmem_shape = [view_row_shape * world_size, attn_dim_per_tp_val]
shmem_tensor = pypto.distributed.create_shmem_tensor(
group_name, world_size, pypto.DT_BF16, shmem_shape)
shmem_barrier_signal = pypto.distributed.create_shmem_signal(group_name, world_size)
my_pe = pypto.distributed.my_symbolic_pe(group_name)
in_tensor_tile = pypto.view(
in_tensor, (view_row_shape, in_tensor.shape[1]), [bs_idx * view_row_shape, 0],
valid_shape=[(batch_size_per_rank - bs_idx * view_row_shape).min(view_row_shape), in_tensor.shape[1]])
pypto.set_vec_tile_shapes(view_row_shape, attn_dim_per_tp_val)
data_clear_out = pypto.distributed.shmem_clear_data(
shmem_tensor, shmem_shape, [0, 0], pred=[in_tensor_tile])
barrier_out = pypto.distributed.shmem_barrier_all(
shmem_barrier_signal, [data_clear_out])
gathered_tensor = pypto.tensor([view_row_shape * world_size, attn_dim_per_tp_val], pypto.DT_BF16, "gathered")
pypto.set_vec_tile_shapes(view_row_shape, attn_dim_per_tp_val)
for dyn_idx in range(world_size):
put_out = pypto.distributed.shmem_put(
in_tensor_tile, [my_pe * view_row_shape, 0], shmem_tensor, dyn_idx,
put_op=pypto.AtomicType.SET, pred=[barrier_out])
pypto.distributed.shmem_signal(
shmem_tensor, dyn_idx, 1, [view_row_shape, attn_dim_per_tp_val],
[my_pe * view_row_shape, 0], target_pe=dyn_idx, sig_op=pypto.AtomicType.ADD, pred=[put_out])
valid_row = (batch_size_per_rank - bs_idx * view_row_shape).min(view_row_shape)
wait_until_out = pypto.distributed.shmem_wait_until(
shmem_tensor, my_pe, 1, [view_row_shape, attn_dim_per_tp_val],
[dyn_idx * view_row_shape, 0], cmp=pypto.OpType.EQ, clear_signal=True, pred=[in_tensor_tile])
pypto.set_pass_options(sg_set_scope=(1, True, False))
shmem_get_out = pypto.experimental.shmem_load(
shmem_tensor, my_pe, [view_row_shape, attn_dim_per_tp_val],
[dyn_idx * view_row_shape, 0], pred=[wait_until_out],
valid_shape=[valid_row, attn_dim_per_tp_val])
gathered_tensor[dyn_idx * view_row_shape:] = shmem_get_out
pypto.set_pass_options(sg_set_scope=-1)
pypto.set_pass_options(sg_set_scope=(1, True, False))
pypto.set_cube_tile_shapes([8, 8], [128, 256], [256, 512])
matmul_result = pypto.matmul(gathered_tensor, matmul_weight, pypto.DT_BF16, b_trans=True)
out_tensor[bs_idx * pypto.symbolic_scalar(view_row_shape):] = matmul_result
pypto.set_pass_options(sg_set_scope=-1)
_KERNEL_MAP = {
'normal': allgather_matmul_kernel,
}
def generate_golden_data(config: DistributedConfig):
batch_size_per_rank = 8
attn_dim_per_tp = 1536
hidden_size = 5120
torch.manual_seed(42)
input_datas = []
for _ in range(config.world_size):
in_tensor = torch.randn((batch_size_per_rank, attn_dim_per_tp), dtype=torch.bfloat16)
matmul_weight = torch.randn((hidden_size, attn_dim_per_tp), dtype=torch.bfloat16)
input_data = [in_tensor, matmul_weight]
input_datas.append(input_data)
output_datas = allgather_matmul_result_golden(batch_size_per_rank, input_datas)
return input_datas, output_datas, attn_dim_per_tp
def allgather_matmul_result_golden(batch_size_per_rank, input_datas):
output_datas = []
full_batch_size = batch_size_per_rank * len(input_datas)
for inputs in input_datas:
allgather_result = torch.zeros((full_batch_size, input_datas[0][0].shape[1]), dtype=torch.bfloat16)
for dyn_idx, input_data in enumerate(input_datas):
in_tensor = input_data[0]
allgather_result[dyn_idx * batch_size_per_rank:(dyn_idx + 1) * batch_size_per_rank, :] = in_tensor
matmul_weight = inputs[1]
matmul_result = torch.matmul(allgather_result, matmul_weight.T)
output_datas.append([matmul_result])
return output_datas
def get_check_threshold(world_size: int, golden_tensor: torch.Tensor):
scale = np.mean(np.abs(golden_tensor.cpu().flatten().tolist())) + 1e-12
eps = 2 ** -7
tolerance_coeff = min((1 + 0.25 * np.log2(world_size)) * 1.5, 3.0)
rtol = tolerance_coeff * eps * np.sqrt(world_size)
atol = rtol * scale
return rtol, atol
def allgather_matmul_worker(
config: DistributedConfig,
input_data: list,
output_data: list,
logical_rank_id: int,
error_queue: mp.Queue = None,
kernel_name: str = 'normal',
attn_dim_per_tp: int = 1536,
):
try:
kernel = _KERNEL_MAP[kernel_name]
groups = config.init_hccl_comm(logical_rank_id)
device = f'npu:{config.get_physical_device_id(logical_rank_id)}'
in_tensor, matmul_weight = input_data
golden_out_tensor = output_data[0]
out_tensor = torch.empty(
(in_tensor.shape[0] * config.world_size, matmul_weight.shape[0]),
dtype=torch.bfloat16, device=device)
inputs = [in_tensor.to(device), matmul_weight.to(device), out_tensor]
kernel(*inputs, groups[0], config.world_size, attn_dim_per_tp)
out_tensor_rtol, out_tensor_atol = get_check_threshold(config.world_size, golden_out_tensor)
np.testing.assert_allclose(
np.array(out_tensor.cpu().flatten().tolist()),
np.array(golden_out_tensor.cpu().flatten().tolist()),
rtol=out_tensor_rtol,
atol=out_tensor_atol,
)
except Exception as e:
if error_queue is not None:
error_queue.put((logical_rank_id, str(e), traceback.format_exc()))
raise
@allow_in_graph
def allgather_matmul(
in_tensor: torch.Tensor,
matmul_weight: torch.Tensor,
group_name: str,
world_size: int,
):
if isinstance(in_tensor, fake_tensor.FakeTensor):
return None
attn_dim_per_tp = in_tensor.shape[1]
out_tensor = torch.empty(
(in_tensor.shape[0] * world_size, matmul_weight.shape[0]),
dtype=torch.bfloat16, device=in_tensor.device)
allgather_matmul_kernel(in_tensor, matmul_weight, out_tensor, group_name, world_size, attn_dim_per_tp)
return out_tensor
@pytest.mark.world_size(2)
def test_allgather_matmul():
mp.set_start_method('spawn', force=True)
config = DistributedConfig(world_size=2)
input_datas, output_datas, attn_dim_per_tp = generate_golden_data(config)
error_queue = mp.Queue()
processes = []
for i in range(config.world_size):
p = mp.Process(
target=allgather_matmul_worker,
args=(config, input_datas[i], output_datas[i], i, error_queue, 'normal', attn_dim_per_tp)
)
p.start()
processes.append(p)
for p in processes:
p.join()
collect_process_errors(processes, error_queue)
def main():
test_allgather_matmul()
if __name__ == '__main__':
main()