npu_bmm_reducescatter_alltoall对外接口
def npu_bmm_reducescatter_alltoall(x: Tensor,
weight: Tensor,
group_ep: str,
group_ep_worldsize: int,
group_tp: str,
group_tp_worldsize: int,
*,
bias: Optional[Tensor] = None,
shard_type: Optional[int] = 0) -> Tensor:
计算逻辑: BatchMatMulReduceScatterAllToAll是实现BatchMatMul计算与ReduceScatter、AllToAll集合通信并行的算子。 大体计算流程为:BatchMatMul计算-->转置(shard_type等于0时需要)-->ReduceScatter集合通信-->Add-->AllToAll集合通信
计算逻辑如下,其中out为最终输出,x weight bias为输入
bmmOut=BatchMatMul(x,weight) bmmOut = BatchMatMul(x,weight)
reduceScatterOut=ReduceScatter(bmmOut) reduceScatterOut = ReduceScatter(bmmOut)
addOut=Add(reduceScatterOut,bias) addOut = Add(reduceScatterOut, bias)
out=AllToAll(addOut) out = AllToAll(addOut)
输入输出及属性说明
输入:
- x:必选输入,Tensor,数据类型float16,bfloat16,必须为3维。BatchMatMul计算的左矩阵。
- weight:必选输入,Tensor,数据类型float16, bfloat16,必须为3维,类型与x保持一致。BatchMatMul计算的右矩阵。
- bias:可选输入,Tensor,数据类型float16, float32。x为float16时,bias需为float16;x为bfloat16时,bias需为float32。支持两维或三维。BatchMatMul计算的bias。(由于要进行ReduceScatter通信,因此需要在通信之后再Add)。
输出:
- out:Tensor,数据类型float16, bfloat16,必须为3维。最终计算结果,类型与输入x保持一致。
属性:
- group_ep:必选属性,str。ep通信域名称,专家并行的通信域。
- group_ep_worldsize:必选属性,int。ep通信域size,支持2/4/8/16/32。
- group_tp:必选属性,str。tp通信域名称,Tensor并行的通信域。
- group_tp_worldsize:必选属性,int。tp通信域size,支持2/4/8/16/32。
- shard_type:可选属性,int,默认值为0。0表示输出在H维度按tp分片,1表示输出在C维度按tp分片。
输入限制
因为集合通信及BatchMatMul计算所需,输入输出shape需满足以下数学关系:(其中ep=group_ep_worldsize,tp=group_tp_worldsize)
按H轴进行ReduceScatter场景,即shard_type为0场景:
- x: (E/ep, ep*C, M/tp)
- weight:(E/ep, M/tp, H)
- bias:(E/ep, 1, H/tp) 两维时为(E/ep, H/tp)
- out:(E, C, H/tp)
按C轴进行ReduceScatter场景,即shard_type为1场景:
- x: (E/ep, ep*tp*C/tp, M/tp)
- weight:(E/ep, M/tp, H)
- bias:(E/ep, 1, H) 两维时为(E/ep, H)
- out:(E, C/tp, H)
数据关系说明:
- 比如x.size(0)等于E/tp,out.size(0)等于E,则表示,out.size(0) = ep*x.size(0),out.size(0)是ep的整数倍;其他关系类似
- E的取值范围为[2, 512],且E是ep的整数倍;
- H的取值范围为:[1, 65535],当shard_type为0时,H需为tp的整数倍;
- M/tp的取值范围为:[1, 65535];
- E/ep的取值范围为:[1, 32];
- ep、tp均仅支持2、4、8、16、32;
- group_ep和group_tp名称不能相同;
- C大于0,上限为算子device内存上限,当shard_type为1时,C需为tp的整数倍;
- 不支持跨超节点,只支持超节点内。
npu_bmm_reducescatter_alltoall 类的调用示例(待验证)
在终端调用命令如下:
python3 -m torch.distributed.launch --nproc_per_node 8 --master_addr 127.0.0.1 --master_port 29500 demo_test.py
注:master_addr和master_port参数需用户根据实际情况设置,8表示ep_size*tp_size,按实际修改
demo_test.py的示例代码如下:
import os
import pytest
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
import torch_npu
from mindspeed.ops.npu_bmm_reduce_scatter_all_to_all import npu_bmm_reducescatter_alltoall
world_size = 8
ep_size = 4
tp_size = 2
def get_hcomm_info(n, i):
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcomm_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(i)
else:
hcomm_info = default_pg.get_hccl_comm_name(i)
return hcomm_info
def setup_ep_tp(rank, tp_size, ep_size, backend_type):
# 初始化EP域
print("device %d initialize ep group" % rank, flush=True)
for i in range(tp_size):
ep_ranks = [x + ep_size * i for x in range(ep_size)]
ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
if rank in ep_ranks:
ep_group_tmp = ep_group
print("device %d initialize tp group" % rank, flush=True)
for i in range(ep_size):
tp_ranks = [x * ep_size + i for x in range(tp_size)]
tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
if rank in tp_ranks:
tp_group_tmp = tp_group
return ep_group_tmp, tp_group_tmp
def get_ep_tp_hcomm_info(rank, ep_size, tp_size):
ep_group, tp_group = setup_ep_tp(rank, tp_size, ep_size, "hccl")
if torch.__version__ > '2.0.1':
ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
tp_hcomm_info = tp_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
ep_hcomm_info = ep_group.get_hccl_comm_name(rank)
tp_hcomm_info = tp_group.get_hccl_comm_name(rank)
return ep_hcomm_info, tp_hcomm_info
def test_npu_bmm_reducescatter_alltoall(dtype, y_shard_type, transpose_weight):
rank = int(os.environ["LOCAL_RANK"])
torch_npu.npu.set_device(rank)
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
ep_group, tp_group = get_ep_tp_hcomm_info(rank, ep_size, tp_size)
hcomm_info = get_hcomm_info(world_size, rank)
print(f'current device: {torch_npu.npu.current_device()}, local rank = {rank}, hcomm_info = {ep_group}, {tp_group}')
E, C, H, M = 4, 1024, 1024, 8192
if y_shard_type == 0:
x_shape = (E / ep_size, ep_size * C, M / tp_size)
bias_shape = (E / ep_size, 1, H / tp_size)
else:
x_shape = (E / ep_size, tp_size * ep_size * C, M / tp_size)
bias_shape = (E / ep_size, 1, H)
weight_shape = (E / ep_size, M / tp_size, H)
if transpose_weight == True:
weight_shape = (E / ep_size, H, M / tp_size)
x_shape = tuple(int(item) for item in x_shape)
weight_shape = tuple(int(item) for item in weight_shape)
bias_shape = tuple(int(item) for item in bias_shape)
x = torch.rand(x_shape)
weight = torch.rand(weight_shape)
bias = torch.rand(bias_shape)
x_npu = x.npu().to(dtype)
weight_npu = weight.npu().to(dtype)
if transpose_weight == True:
print(f'!!!!before transpose, weight_npu.size()={weight_npu.size()}')
weight_npu = weight_npu.transpose(1, 2)
print(f'!!!!after transpose, weight_npu.size()={weight_npu.size()}')
print(f'!!!!after transpose, weight_npu.is_contiguous()={weight_npu.is_contiguous()}')
bias_npu = bias.npu().to(dtype)
y = npu_bmm_reducescatter_alltoall(x_npu,
weight_npu,
ep_group,
ep_size,
tp_group,
tp_size,
bias=bias_npu,
shard_type=y_shard_type)
print(f'y_shape = {y.size()}')
if y_shard_type == 0:
assert y.size() == (E, C, int(H / tp_size))
else:
assert y.size() == (E, C, H)
return y
if __name__ == '__main__':
dtype = torch.float16
shard_type = 1
transpose_weight = False
y_npu = test_npu_bmm_reducescatter_alltoall(dtype, shard_type, transpose_weight)
rank = int(os.environ["LOCAL_RANK"])
if rank == 0:
for i, y in enumerate(y_npu):
y.cpu().numpy().tofile(f"./y_{i}.bin")