torch_npu.npu_all_gather_base_mm

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

  • API功能:TP切分场景下,融合allgathermatmul,实现通信和计算流水并行。

  • 计算公式: x1x1代表输入x1

    基础场景:

    output=allgather(x1)@x2+biasoutput = allgather(x1) \mathbin{@} x2 + bias

    gather_out=allgather(x1)gather\_out = allgather(x1)

    量化场景:

    output=(allgather(x1_scale)∗x2_scale)∗(allgather(x1)@x2+bias)output = (allgather(x1\_scale) * x2\_scale) * (allgather(x1)\mathbin{@} x2 + bias)

    gather_out=allgather(x1)gather\_out = allgather(x1)

Note

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

函数原型

torch_npu.npu_all_gather_base_mm(x1, x2, hcom, world_size, bias=None, x1_scale=None, x2_scale=None, gather_index=0, gather_output=True, comm_turn=0, output_dtype=None, comm_mode=None) -> tuple[Tensor, Tensor]

参数说明

  • x1 (Tensor):必选参数,表示矩阵乘法中的左矩阵,数据类型支持float16bfloat16int8,数据格式支持ND,输入shape支持2维,形如(m, k),轴满足matmul算子入参要求,第二轴与x2的第一轴相等,且k的取值范围为[256, 65535)。
  • x2 (Tensor):必选参数,表示矩阵乘法中的右矩阵,数据类型需要和x1保持一致,数据格式支持NDNDNZNZNZNZ仅在comm_modeaiv时支持。输入shape支持2维,形如(k, n),轴满足matmul算子入参要求,第一轴与x1的第二轴相等,且k的取值范围为[256, 65535)。
  • hcom (string):必选参数,通信域handle名,通过get_hccl_comm_name接口获取。
  • world_size (int):必选参数,通信域内的rank总数。
    • Atlas A2 训练系列产品:支持2、4、8卡,支持HCCS链路all mesh组网(每张卡和其它卡两两相连)。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持2、4、8、16、32卡,支持HCCS链路double ring组网(多张卡按顺序组成一个圈,每张卡只和左右卡相连)。
  • bias (Tensor):可选参数,数据类型支持float16、bfloat16,数据格式支持ND格式。数据类型需要和x1保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入为非0的场景。
  • x1_scale (Tensor):可选参数,mm左矩阵反量化参数。数据类型支持float32,数据格式支持NDND格式。数据维度为(m, 1),支持pertoken量化。
  • x2_scale (Tensor):可选参数,mm右矩阵反量化参数。数据类型支持float32int64,数据格式支持NDND格式。数据维度为(1, n),支持perchannel量化。如需传入int64数据类型的,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的x2_scale
  • gather_index (int):可选参数,表示gather操作对象,0表示对x1做gather,1表示对x2做gather。默认值0。当前版本仅支持输入0。
  • gather_output (bool):可选参数,表示是否需要gather输出。默认值True。
  • comm_turn (int):可选参数,表示rank间通信切分粒度,默认值为0,表示默认的切分方式。当前版本仅支持输入0。
  • output_dtype (ScalarType):可选参数,表示第一个输出的数据类型。仅支持在量化场景且x1_scalex2_scale均为float32时,可指定输出数据类型为bfloat16float16,默认值为bfloat16
  • comm_mode (string):可选参数,表示通信模式,支持ai_cpuaiv两种模式。ai_cpu模式仅支持基础场景。aiv模式支持基础场景和量化场景。默认值为ai_cpu

返回值说明

  • output (Tensor):第一个输出Tensor,allgather+matmul的结果。 基础场景时数据类型和x1保持一致。 量化场景下,x2_scaleint64数据类型时,输出数据类型为float16x1_scalex2_scale均为float32时,输出数据类型由output_dtype指定,默认为bfloat16
  • gather_out (Tensor):第二个输出Tensor,allgather的结果,由gather_output参数控制是否输出,gather_output为False时,返回空Tensor。

约束说明

  • x1不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与x1的最后一维相同,满足matmul的计算条件。
  • comm_modeai_cpu时:
    • 该接口支持训练场景下使用。
    • 该接口支持图模式。
    • Atlas A2 训练系列产品:一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce),仅支持相同通信域。
  • comm_modeaiv时,训练和推理场景均可使用。

调用示例

  • 单算子模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
    
        tensor_allgather_shape = x1_shape
        single_shape = [x1_shape[0] // world_size, x1_shape[1]]
    
        input_ = torch.randn(single_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output, gather_out = torch_npu.npu_all_gather_base_mm(input_, weight, hcomm_info, world_size)
    
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
    
        mp.spawn(run_all_gather_base_mm, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
    
  • 图模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    class ALLGATHER_MM_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, input, weight, hcomm_info, world_size, gather_output):
            output, gather_output = torch_npu.npu_all_gather_base_mm(input, weight, hcomm_info, world_size, gather_output=gather_output)
            return output, gather_output
    def define_model(model, graph_type):
        import torchair
        if graph_type == 1:  # 传统入图模式,静态shape+在线编译场景
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=False)
        elif graph_type == 2:  # ACLNN入图模式,动态shape+二进制
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=True)
        else:
            print("Error type")
        return model
    def get_graph(input, weight, hcomm_info, world_size, gather_output):
        model = ALLGATHER_MM_GRAPH_Model()
        model = define_model(model, 2)
        model_output = model(input, weight, hcomm_info, world_size, gather_output=gather_output)
        output_npu = model_output[0]
        gather_output_npu = model_output[1]
        return output_npu, gather_output_npu
    def run_all_gather_base_mm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
        single_shape = [x1_shape[0] // world_size, x1_shape[1]]
        input = torch.randn(single_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        is_gather_out = True
        output, gather_out = get_graph(input, weight, hcomm_info, world_size, is_gather_out)
        print("output:", output)
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
        mp.spawn(run_all_gather_base_mm, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)