304e7895创建于 2025年12月31日历史提交

torch_npu.npu_gmm_alltoallv

产品支持情况

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

功能说明

  • API功能:MoE网络中,完成路由专家GroupedMatMul、AlltoAllv融合并实现与共享专家MatMul并行融合,先计算后通信。

  • 路由专家计算公式:

    • gmm_x指路由专家GroupedMatMul计算的左矩阵。
    • gmm_weight指路由专家GroupedMatMul计算的右矩阵。
    • gmm_y指路由专家进行GroupedMatMul计算的输出,后续用于Unpermute计算。
    • unpermute_out是gmm_y进行Unpermute计算的输出结果,作为AlltoAllv通信的输入。
    • y指对unpermute_out进行AlltoAllv通信输出。
  • 共享专家计算公式:

    • mm_x指共享专家MatMul计算的左矩阵。
    • mm_weight指共享专家MatMul计算的右矩阵。
    • mm_y指共享专家MatMul计算的输出。

函数原型

torch_npu.npu_gmm_alltoallv(gmm_x, gmm_weight, hcom, ep_world_size, send_counts, recv_counts, *, send_counts_tensor=None, recv_counts_tensor=None, mm_x=None, mm_weight=None, trans_gmm_weight=False, trans_mm_weight=False) -> (Tensor, Tensor)

参数说明

  • gmm_xTensor):必选参数,GroupedMatMul计算的左矩阵。数据类型支持float16bfloat16,支持2维,shape为(A,H1)(A, H1),数据格式支持ND。
  • gmm_weightTensor):必选参数,GroupedMatMul计算的右矩阵。数据类型与gmm_x保持一致,支持3维,shape为(e,H1,N1)(e, H1, N1),数据格式支持ND。
  • hcomstr):必选参数,专家并行的通信域名,字符串长度要求(0, 128)。
  • ep_world_sizeint):必选参数,EP通信域size,取值支持8、16、32、64、128。
  • send_countsList[int]):必选参数,为一个列表,表示发送给其他卡的token数,列表长度为卡数。列表中元素的数据类型支持int,取值为e*ep_world_size,最大值为256。
  • recv_countsList[int]):必选参数,为一个列表,表示接收其他卡的token数,列表长度为卡数。列表中元素的数据类型支持int,取值大小为e*ep_world_size,最大值为256。
  • send_counts_tensorTensor):可选参数,数据类型支持int,shape为(e∗ep_world_size,)(e*ep\_world\_size,),数据格式支持ND。当前版本暂不支持,使用默认值即可。
  • recv_counts_tensorTensor):可选参数,数据类型支持int,shape为(e∗ep_world_size,)(e*ep\_world\_size,),数据格式支持ND。当前版本暂不支持,使用默认值即可。
  • mm_xTensor):可选参数,共享专家MatMul计算中的左矩阵。当需要融合共享专家矩阵计算时,该参数必选,数据类型支持float16bfloat16,支持2维,shape为(BS,H2)(BS, H2)
  • mm_weightTensor):可选参数,共享专家MatMul计算中的右矩阵。当需要融合共享专家矩阵计算时,该参数必选,数据类型与mm_x保持一致,支持2维,shape为(H2,N2)(H2, N2)
  • trans_gmm_weightbool):可选参数,GroupedMatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。
  • trans_mm_weightbool):可选参数,共享专家MatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。

返回值说明

  • yTensor):表示最终计算结果,数据类型与输入gmm_x保持一致,支持2维,shape为(BSK,N1)(BSK, N1)
  • mm_yTensor):共享专家MatMul的输出,数据类型与mm_x保持一致,支持2维,shape为(BS,N2)(BS, N2)。仅当传入mm_xmm_weight才输出。

约束说明

  • 该接口支持推理场景下使用。
  • 该接口支持图模式。
  • 单卡通信量取值大于等于2MB。
  • 输入参数Tensor中shape使用的变量说明:
    • BSK:本卡接收的token数(BS*K=BSK),是recv_counts参数累加之和,取值范围(0, 52428800)。

    • H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。

    • H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。

    • e:表示单卡上专家个数,e<=32,e * ep_world_size最大支持256。

    • N1:表示路由专家的head_num,取值范围(0, 65536)。

    • N2:表示共享专家的head_num,取值范围(0, 65536)。

    • BS:batch sequence size。

    • K:表示选取top_k个专家,K的范围[2, 8]。

    • A:本卡发送的token数,是send_counts参数累加之和。

    • EP通信域内所有卡上的A参数的累加和等于所有卡上的BSK参数的累加和。

调用示例

  • 单算子模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    
    def run_npu_gmm_alltoallv(rank, ep_world_size, master_ip, master_port, gmm_x, gmm_w, send_counts, recv_counts, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=ep_world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcom_info = default_pg.get_hccl_comm_name(rank)
    
        input = torch.randn(gmm_x, dtype=dtype).npu()
        weight = torch.randn(gmm_w, dtype=dtype).npu()
    
        print(torch_npu.npu_gmm_alltoallv(gmm_x =input,
                                                gmm_weight = weight,
                                                hcom= hcom_info,
                                                ep_world_size = ep_world_size,
                                                send_counts = list(send_counts),
                                                recv_counts = list(recv_counts),
                                                send_counts_tensor = None, #send_counts_tensor,
                                                recv_counts_tensor = None, #recv_counts_tensor,
                                                mm_x = None,
                                                mm_weight = None,
                                                trans_gmm_weight = False,
                                                trans_mm_weight  = False))
    
    if __name__ == "__main__":
        epWorkSize = 8
        e = 4
        master_ip = '127.0.0.1'
        master_port = '50001'
        BS = 512
        K = 8
        gmm_x_shape = [BS*K, 2048]
        gmm_weight_shape = [e, 2048, 2048]
        send_counts = [128] * (e * epWorkSize)
        recv_counts = [128] * (e * epWorkSize)
    
        dtype = torch.float16
    
        mp.spawn(run_npu_gmm_alltoallv, args=(epWorkSize, master_ip, master_port, gmm_x_shape, gmm_weight_shape, send_counts, recv_counts, dtype), nprocs=epWorkSize)
    
  • 图模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torchair
    
    class GMM_ALLTOALLV_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self,gmm_x, gmm_weight,
                    hcom, ep_world_size,
                    send_counts, recv_counts, send_counts_tensor, recv_counts_tensor,
                    mm_x, mm_weight,
                    trans_gmm_weight, trans_mm_weight):
            return torch_npu.npu_gmm_alltoallv(gmm_x =gmm_x,
                                                gmm_weight = gmm_weight,
                                                hcom= hcom,
                                                ep_world_size = ep_world_size,
                                                send_counts = list(send_counts),
                                                recv_counts = list(recv_counts),
                                                send_counts_tensor = None, #send_counts_tensor,
                                                recv_counts_tensor = None, #recv_counts_tensor,
                                                mm_x = mm_x,
                                                mm_weight = mm_weight,
                                                trans_gmm_weight = trans_gmm_weight,
                                                trans_mm_weight  = trans_mm_weight)
    
    def run_npu_gmm_alltoallv(rank, ep_world_size, master_ip, master_port, gmm_x, gmm_w, send_counts, recv_counts, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=ep_world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcom_info = default_pg.get_hccl_comm_name(rank)
    
        input = torch.randn(gmm_x, dtype=dtype).npu()
        weight = torch.randn(gmm_w, dtype=dtype).npu()
    
        model = GMM_ALLTOALLV_GRAPH_Model()
        npu_backend = torchair.get_npu_backend(compiler_config=None)
        # 静态图:dynamic=False;动态图:dynamic=True
        model = torch.compile(GMM_ALLTOALLV_GRAPH_Model(), backend=npu_backend, dynamic=False)
        print(model(gmm_x=input,
                        gmm_weight=weight,
                        send_counts_tensor=None,
                        recv_counts_tensor=None,
                        mm_x=None,
                        mm_weight=None,
                        hcom=hcom_info,
                        ep_world_size=ep_world_size,
                        send_counts=send_counts,
                        recv_counts=recv_counts,
                        trans_gmm_weight=False,
                        trans_mm_weight=False))
    
    if __name__ == "__main__":
        epWorkSize = 8
        e = 4
        master_ip = '127.0.0.1'
        master_port = '50001'
        BS = 512
        K = 8
        gmm_x_shape = [BS*K, 2048]
        gmm_weight_shape = [e, 2048, 2048]
        send_counts = [128] * (e * epWorkSize)
        recv_counts = [128] * (e * epWorkSize)
    
        dtype = torch.float16
    
        mp.spawn(run_npu_gmm_alltoallv, args=(epWorkSize, master_ip, master_port, gmm_x_shape, gmm_weight_shape, send_counts, recv_counts, dtype), nprocs=epWorkSize)