torch_npu.npu_mm_reduce_scatter_base
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品 | √ |
功能说明
-
API功能:TP切分场景下,实现matmul和reduce_scatter的融合,融合算子内部实现计算和通信流水并行。支持perchannel,pertoken量化。
-
计算公式: x1x1代表输入
input基础场景:
output=reducescatter(x1@x2+bias)output = reducescatter(x1 \mathbin{@} x2 + bias)
量化场景:
output=reducescatter((x1_scale∗x2_scale)∗(x1@x2+bias))output = reducescatter((x1\_scale * x2\_scale) * (x1 \mathbin{@} x2 + bias))
Note
使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。
函数原型
torch_npu.npu_mm_reduce_scatter_base(input, x2, hcom, world_size, *, reduce_op='sum', bias=None, x1_scale=None, x2_scale=None, comm_turn=0, output_dtype=None, comm_mode=None) -> Tensor
参数说明
-
input (
Tensor):必选参数。数据类型支持float16、bfloat16、int8,数据格式支持NDND,输入shape支持2维,形如(m, k)。 -
x2 (
Tensor):必选参数。数据类型与input一致,数据格式支持NDND、NZNZ。NZNZ仅在comm_mode为aiv时支持。输入shape支持2维,形如(k, n)。轴满足matmul算子入参要求,k轴相等,且k轴取值范围为[256, 65535),m轴需要整除world_size。 -
hcom (
str):必选参数。通信域handle名,通过get_hccl_comm_name接口获取。 -
world_size (
int):必选参数。通信域内的rank总数。- Atlas A2 训练系列产品支持2、4、8卡,支持HCCS链路all mesh组网(每张卡和其它卡两两相连)。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品支持2、4、8、16、32卡,支持HCCS链路double ring组网(多张卡按顺序组成一个圈,每张卡只和左右卡相连)。
-
*:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
-
reduce_op (
str):可选参数。reduce操作类型,当前仅支持'sum',默认值为'sum'。 -
bias (
Tensor):可选参数。数据类型支持float16、bfloat16,数据格式支持NDND格式。数据类型需要和input保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入为非0的场景。 -
x1_scale (
Tensor):可选参数。mm左矩阵反量化参数。数据类型支持float32,数据格式支持NDND格式。数据维度为(m, 1), 支持pertoken量化。 -
x2_scale (
Tensor):可选参数。mm右矩阵反量化参数。数据类型支持float32、int64,数据格式支持NDND格式。数据维度为(1, n), 支持perchannel量化。如需传入int64数据类型的,需要提前调用torch_npu.npu_trans_quant_param来获取int64数据类型的x2_scale。 -
comm_turn (
int):可选参数。表示rank间通信切分粒度,默认值为0,表示默认的切分方式。当前版本仅支持输入0。 -
output_dtype (
ScalarType):可选参数。表示输出数据类型。仅支持在量化场景且x1_scale和x2_scale均为float32时,可指定输出数据类型为bfloat16或float16,默认值为bfloat16。 -
comm_mode (
str):可选参数。表示通信模式,支持ai_cpu、aiv两种模式。ai_cpu模式仅支持基础场景。aiv模式支持基础场景和量化场景。默认值为ai_cpu。
返回值说明
Tensor
shape维度和input保持一致。
基础场景时数据类型和input保持一致。
量化场景下,x2_scale为int64数据类型时,输出数据类型为float16。x1_scale和x2_scale均为float32时, 输出数据类型由output_dtype指定,默认为bfloat16。
约束说明
input不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与input的最后一维相同,满足matmul的计算条件。comm_mode为ai_cpu时:- 该接口仅在训练场景下使用。
- 该接口支持图模式。
- Atlas A2 训练系列产品:一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce),仅支持相同通信域。
comm_mode为aiv时,训练和推理场景均可使用。
调用示例
-
单算子模式调用
import torch import torch_npu import torch.distributed as dist import torch.multiprocessing as mp def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype): torch_npu.npu.set_device(rank) init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) from torch.distributed.distributed_c10d import _get_default_group default_pg = _get_default_group() if torch.__version__ > '2.0.1': hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = default_pg.get_hccl_comm_name(rank) input_ = torch.randn(x1_shape, dtype=dtype).npu() weight = torch.randn(x2_shape, dtype=dtype).npu() output = torch_npu.npu_mm_reduce_scatter_base(input_, weight, hcomm_info, world_size) if __name__ == "__main__": worksize = 8 master_ip = '127.0.0.1' master_port = '50001' x1_shape = [128, 512] x2_shape = [512, 64] dtype = torch.float16 mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize) -
图模式调用
import torch import torch_npu import torch.distributed as dist import torch.multiprocessing as mp class MM_REDUCESCATTER_GRAPH_Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input, weight, hcomm_info, world_size, reduce_op): output = torch_npu.npu_mm_reduce_scatter_base(input, weight, hcomm_info, world_size, reduce_op=reduce_op) return output def define_model(model, graph_type): import torchair if graph_type == 1: # 传统入图模式,静态shape+在线编译场景 npu_backend = torchair.get_npu_backend(compiler_config=None) model = torch.compile(model, backend=npu_backend, dynamic=False) elif graph_type == 2: # ACLNN入图模式,动态shape+二进制 npu_backend = torchair.get_npu_backend(compiler_config=None) model = torch.compile(model, backend=npu_backend, dynamic=True) else: print("Error type") return model def get_graph(input, weight, hcomm_info, world_size): model = MM_REDUCESCATTER_GRAPH_Model() model = define_model(model, 2) model_output = model(input, weight, hcomm_info, world_size, reduce_op="sum") return model_output def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype): torch_npu.npu.set_device(rank) init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) from torch.distributed.distributed_c10d import _get_default_group default_pg = _get_default_group() if torch.__version__ > '2.0.1': hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = default_pg.get_hccl_comm_name(rank) input = torch.randn(x1_shape, dtype=dtype).npu() weight = torch.randn(x2_shape, dtype=dtype).npu() output = get_graph(input, weight, hcomm_info, world_size) print("output:", output) if __name__ == "__main__": worksize = 8 master_ip = '127.0.0.1' master_port = '50001' x1_shape = [128, 512] x2_shape = [512, 64] dtype = torch.float16 mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)