torch_npu.npu_mm_reduce_scatter_base

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品

功能说明

  • API功能:TP切分场景下,实现matmul和reduce_scatter的融合,融合算子内部实现计算和通信流水并行。支持perchannel,pertoken量化。

  • 计算公式: x1x1代表输入input

    基础场景:

    output=reducescatter(x1@x2+bias)output = reducescatter(x1 \mathbin{@} x2 + bias)

    量化场景:

    output=reducescatter((x1_scale∗x2_scale)∗(x1@x2+bias))output = reducescatter((x1\_scale * x2\_scale) * (x1 \mathbin{@} x2 + bias))

Note

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

函数原型

torch_npu.npu_mm_reduce_scatter_base(input, x2, hcom, world_size, *, reduce_op='sum', bias=None, x1_scale=None, x2_scale=None, comm_turn=0, output_dtype=None, comm_mode=None) -> Tensor

参数说明

  • input (Tensor):必选参数。数据类型支持float16bfloat16int8,数据格式支持NDND,输入shape支持2维,形如(m, k)。

  • x2 (Tensor):必选参数。数据类型与input一致,数据格式支持NDNDNZNZNZNZ仅在comm_modeaiv时支持。输入shape支持2维,形如(k, n)。轴满足matmul算子入参要求,k轴相等,且k轴取值范围为[256, 65535),m轴需要整除world_size

  • hcom (str):必选参数。通信域handle名,通过get_hccl_comm_name接口获取。

  • world_size (int):必选参数。通信域内的rank总数。

    • Atlas A2 训练系列产品支持2、4、8卡,支持HCCS链路all mesh组网(每张卡和其它卡两两相连)。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品支持2、4、8、16、32卡,支持HCCS链路double ring组网(多张卡按顺序组成一个圈,每张卡只和左右卡相连)。
  • *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。

  • reduce_op (str):可选参数。reduce操作类型,当前仅支持'sum',默认值为'sum'。

  • bias (Tensor):可选参数。数据类型支持float16bfloat16,数据格式支持NDND格式。数据类型需要和input保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入为非0的场景。

  • x1_scale (Tensor):可选参数。mm左矩阵反量化参数。数据类型支持float32,数据格式支持NDND格式。数据维度为(m, 1), 支持pertoken量化。

  • x2_scale (Tensor):可选参数。mm右矩阵反量化参数。数据类型支持float32int64,数据格式支持NDND格式。数据维度为(1, n), 支持perchannel量化。如需传入int64数据类型的,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的x2_scale

  • comm_turn (int):可选参数。表示rank间通信切分粒度,默认值为0,表示默认的切分方式。当前版本仅支持输入0。

  • output_dtype (ScalarType):可选参数。表示输出数据类型。仅支持在量化场景且x1_scalex2_scale均为float32时,可指定输出数据类型为bfloat16float16,默认值为bfloat16

  • comm_mode (str):可选参数。表示通信模式,支持ai_cpuaiv两种模式。ai_cpu模式仅支持基础场景。aiv模式支持基础场景和量化场景。默认值为ai_cpu

返回值说明

Tensor

shape维度和input保持一致。 基础场景时数据类型和input保持一致。 量化场景下,x2_scaleint64数据类型时,输出数据类型为float16x1_scalex2_scale均为float32时, 输出数据类型由output_dtype指定,默认为bfloat16

约束说明

  • input不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与input的最后一维相同,满足matmul的计算条件。
  • comm_modeai_cpu时:
    • 该接口仅在训练场景下使用。
    • 该接口支持图模式。
    • Atlas A2 训练系列产品:一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce),仅支持相同通信域。
  • comm_modeaiv时,训练和推理场景均可使用。

调用示例

  • 单算子模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
    
        input_ = torch.randn(x1_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output = torch_npu.npu_mm_reduce_scatter_base(input_, weight, hcomm_info, world_size)
    
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
    
        mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
    
  • 图模式调用

    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    class MM_REDUCESCATTER_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, input, weight, hcomm_info, world_size, reduce_op):
            output = torch_npu.npu_mm_reduce_scatter_base(input, weight, hcomm_info, world_size,
                                                          reduce_op=reduce_op)
            return output
    def define_model(model, graph_type):
        import torchair
        if graph_type == 1:  # 传统入图模式,静态shape+在线编译场景
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=False)
        elif graph_type == 2:  # ACLNN入图模式,动态shape+二进制
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=True)
        else:
            print("Error type")
        return model
    def get_graph(input, weight, hcomm_info, world_size):
        model = MM_REDUCESCATTER_GRAPH_Model()
        model = define_model(model, 2)
        model_output = model(input, weight, hcomm_info, world_size, reduce_op="sum")
        return model_output
    def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
        input = torch.randn(x1_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output = get_graph(input, weight, hcomm_info, world_size)
        print("output:", output)
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
        mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)