npu_alltoall_allgather_bmm external interface
def npu_alltoall_allgather_bmm(
    x: Tensor,
    weight: Tensor,
    group_ep: str,
    group_ep_worldsize: int,
    group_tp: str,
    group_tp_worldsize: int,
    *,
    bias: Optional[Tensor] = None,
    shard_type: Optional[int] = 0,
    act_type: Optional[str] = "None",
    need_allgather_out: Optional[bool] = False,
    need_activation_feature: Optional[bool] = False
) -> Tuple[Tensor, Tensor, Tensor]:
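Computation logic: bmm stands for BatchMatMul. The AllToAllAllGatherBatchMatMul operator runs the AllToAll and AllGather collective communications in parallel with the BatchMatMul computation. The overall flow is: AllToAll collective communication --> AllGather collective communication --> BatchMatMul --> activation (optional).
The computation is as follows, where y1Out, y2OutOptional, and y3OutOptional are the outputs; x, weight, and bias are the inputs; and activating is the activation function (selected by act_type; when act_type is None, no activation function is applied):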
alltoallOut = AllToAll(x)
y2OutOptional = AllGather(alltoallOut)
y3OutOptional = BatchMatMul(y2OutOptional, weight, bias)
y1Out = activating(y3OutOptional)
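The last two stages (BatchMatMul and the optional activation) are purely local and can be sketched on a single device. The following is a minimal pure-Python reference on nested lists, not the operator's implementation. The names `batch_matmul`, `local_stages`, and `relu` are hypothetical, and the sketch assumes a 2-D bias of shape (batch, M) that is broadcast across the rows of each batch.

```python
from typing import Callable, List, Optional, Tuple

Matrix = List[List[float]]


def relu(v: float) -> float:
    """Scalar ReLU, used here only as an example activation."""
    return v if v > 0.0 else 0.0


def batch_matmul(lhs: List[Matrix], rhs: List[Matrix],
                 bias: Optional[Matrix] = None) -> List[Matrix]:
    """Per batch b: out[b] = lhs[b] @ rhs[b], plus bias[b] added to every row.

    Assumption: bias is 2-D, one row of length M per batch.
    """
    out = []
    for b, (a, w) in enumerate(zip(lhs, rhs)):
        rows, inner, cols = len(a), len(w), len(w[0])
        res = [[sum(a[i][k] * w[k][j] for k in range(inner))
                for j in range(cols)] for i in range(rows)]
        if bias is not None:
            res = [[res[i][j] + bias[b][j] for j in range(cols)]
                   for i in range(rows)]
        out.append(res)
    return out


def local_stages(y2: List[Matrix], weight: List[Matrix],
                 bias: Optional[Matrix] = None,
                 activating: Optional[Callable[[float], float]] = None
                 ) -> Tuple[List[Matrix], List[Matrix]]:
    """Return (y1, y3): y3 is the BatchMatMul result, y1 the activated result."""
    y3 = batch_matmul(y2, weight, bias)
    if activating is None:
        # Without an activation, y1 is simply the BatchMatMul output.
        return y3, y3
    y1 = [[[activating(v) for v in row] for row in mat] for mat in y3]
    return y1, y3


# One batch, one row, H = 2, identity weight, bias 0.5.
y1, y3 = local_stages([[[1.0, -2.0]]],
                      [[[1.0, 0.0], [0.0, 1.0]]],
                      bias=[[0.5, 0.5]],
                      activating=relu)
print(y3)  # [[[1.5, -1.5]]]
print(y1)  # [[[1.5, 0.0]]]
```

The communication stages are omitted on purpose: they only rearrange and concatenate data across devices and do not change any values.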
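Inputs, outputs, and attributes
Inputs:
- x: required input, Tensor. Supported data types: float16, bfloat16. This input goes through the AllToAll and AllGather collective communications. It must be 3-dimensional, the supported data format is ND, and the communicated result is the left matrix of the BatchMatMul computation.
- weight: required input, Tensor. Supported data types: float16, bfloat16; the type must match x. It must be 3-dimensional, and the supported data format is ND. It is the right matrix of the BatchMatMul computation.
- bias: optional input, Tensor. Supported data types: float16, float32. When x is float16, bias must be float16; when x is bfloat16, bias must be float32. It must be 2- or 3-dimensional, and the supported data format is ND. It is the bias of the BatchMatMul computation.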
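The dtype pairing rules above can be checked before the operator is called. The helper below is a small sketch; `check_input_dtypes` is a hypothetical name, and dtypes are passed as plain strings so the check runs without an NPU or PyTorch installed.

```python
from typing import Optional

# bias dtype required for each supported x dtype, as listed above.
EXPECTED_BIAS_DTYPE = {"float16": "float16", "bfloat16": "float32"}


def check_input_dtypes(x_dtype: str, weight_dtype: str,
                       bias_dtype: Optional[str] = None) -> None:
    """Raise ValueError if the dtype combination violates the documented rules."""
    if x_dtype not in EXPECTED_BIAS_DTYPE:
        raise ValueError(f"x dtype must be float16 or bfloat16, got {x_dtype}")
    if weight_dtype != x_dtype:
        raise ValueError(f"weight dtype {weight_dtype} must match x dtype {x_dtype}")
    if bias_dtype is not None and bias_dtype != EXPECTED_BIAS_DTYPE[x_dtype]:
        raise ValueError(
            f"bias dtype must be {EXPECTED_BIAS_DTYPE[x_dtype]} when x is {x_dtype}, "
            f"got {bias_dtype}")


check_input_dtypes("float16", "float16", "float16")    # passes
check_input_dtypes("bfloat16", "bfloat16", "float32")  # passes
check_input_dtypes("bfloat16", "bfloat16")             # passes, bias omitted
```

With PyTorch tensors, a string such as `str(x.dtype).replace("torch.", "")` could be passed in.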
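Outputs:
- y1Out: Tensor. Supported data types: float16, bfloat16; 3-dimensional only. The final result: the output of the activation function if one is used, otherwise the output of BatchMatMul. The data type matches the input x.
- y2OutOptional: Tensor, optional output. Supported data types: float16, bfloat16; 3-dimensional only. The output of AllGather; the data type matches the input x. It may be needed for the backward pass.
- y3OutOptional: Tensor, optional output. Supported data types: float16, bfloat16; 3-dimensional only. When an activation function is used, this is the output of BatchMatMul; the data type matches the input x.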
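Attributes:
- group_ep: required attribute, str. Name of the ep communication group, which is the expert-parallel communication group.
- group_ep_worldsize: required attribute, int. Size of the ep communication group; supported values are 2/4/8/16/32.
- group_tp: required attribute, str. Name of the tp communication group, which is the tensor-parallel communication group.
- group_tp_worldsize: required attribute, int. Size of the tp communication group; supported values are 2/4/8/16/32.
- shard_type: optional attribute, int, default 0. 0 means AllGather is performed over the tp group along the H dimension; 1 means AllGather is performed over the tp group along the C dimension.
- act_type: optional attribute, str. The activation function type; the default "None" means no activation function. Supported values include GELU/Silu/FastGELU/Relu/None.
- need_allgather_out: whether to output the result after AllGather. The default False means it is not output.
- need_activation_feature: whether to output the result before the activation function (that is, after BatchMatMul). The default False means it is not output. This is only meaningful when act_type is not None.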
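The interaction between act_type and the two output flags can be summarized in a few lines. The sketch below is illustrative only: `requested_outputs` is a hypothetical helper, and this document does not state what the operator returns in the tuple slots of outputs that were not requested, so the helper reports only which outputs carry meaningful data.

```python
from typing import Dict


def requested_outputs(act_type: str = "None",
                      need_allgather_out: bool = False,
                      need_activation_feature: bool = False) -> Dict[str, bool]:
    """Report which outputs are meaningful for a given flag combination."""
    has_activation = act_type != "None"
    return {
        "y1Out": True,  # always produced
        "y2OutOptional": need_allgather_out,
        # The pre-activation result only differs from y1Out if an activation runs.
        "y3OutOptional": need_activation_feature and has_activation,
    }


print(requested_outputs("None", True, True))
# {'y1Out': True, 'y2OutOptional': True, 'y3OutOptional': False}
print(requested_outputs("GELU", False, True))
# {'y1Out': True, 'y2OutOptional': False, 'y3OutOptional': True}
```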
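Input shape constraints
Because of the requirements of the collective communications and the BatchMatMul computation, the input and output shapes must satisfy the following relations (where ep = group_ep_worldsize and tp = group_tp_worldsize).
AllGather along the H axis (shard_type is 0):
- x: (E, C, H/tp)
- weight: (E/ep, H, M/tp)
- bias: 2- or 3-dimensional; the 3-dimensional shape is (E/ep, 1, M/tp), and the 2-dimensional shape is (E/ep, M/tp)
- y1Out: (E/ep, ep*C, M/tp)
- y2OutOptional: (E/ep, ep*C, H)
- y3OutOptional: (E/ep, ep*C, M/tp)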
AllGather along the C axis (shard_type is 1):
- x: (E, C/tp, H)
- weight: (E/ep, H, M/tp)
- bias: 2- or 3-dimensional; the 3-dimensional shape is (E/ep, 1, M/tp), and the 2-dimensional shape is (E/ep, M/tp)
- y1Out: (E/ep, ep*tp*C/tp, M/tp)
- y2OutOptional: (E/ep, ep*tp*C/tp, H)
- y3OutOptional: (E/ep, ep*tp*C/tp, M/tp)
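The two shape tables can be folded into one small function, which is handy for allocating test tensors or asserting on results. This is a sketch; `expected_shapes` is a hypothetical helper, and it reads ep*tp*C/tp as ep*tp*(C/tp), which is an interpretation of the notation above.

```python
from typing import Dict, Tuple


def expected_shapes(E: int, C: int, H: int, M: int,
                    ep: int, tp: int, shard_type: int = 0) -> Dict[str, Tuple[int, ...]]:
    """Per-device shapes of inputs and outputs for shard_type 0 or 1."""
    if shard_type == 0:
        x = (E, C, H // tp)            # H is sharded across the tp group
        gathered_c = ep * C
    elif shard_type == 1:
        x = (E, C // tp, H)            # C is sharded across the tp group
        gathered_c = ep * tp * (C // tp)
    else:
        raise ValueError("shard_type must be 0 or 1")
    e_local, m_local = E // ep, M // tp
    return {
        "x": x,
        "weight": (e_local, H, m_local),
        "bias_3d": (e_local, 1, m_local),
        "bias_2d": (e_local, m_local),
        "y1Out": (e_local, gathered_c, m_local),
        "y2OutOptional": (e_local, gathered_c, H),
        "y3OutOptional": (e_local, gathered_c, m_local),
    }


shapes = expected_shapes(E=4, C=1024, H=1024, M=8192, ep=4, tp=2, shard_type=1)
print(shapes["x"])              # (4, 512, 1024)
print(shapes["weight"])         # (1, 1024, 4096)
print(shapes["y2OutOptional"])  # (1, 4096, 1024)
print(shapes["y1Out"])          # (1, 4096, 4096)
```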
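Notes on the relations:
- For example, x.size(0) equals E and weight.size(0) equals E/ep, which means x.size(0) = ep*weight.size(0) and x.size(0) is an integer multiple of ep; the other relations are analogous.
- E must be in the range [2, 512] and must be an integer multiple of ep.
- H must be in the range [1, 65535]; when shard_type is 0, H must be an integer multiple of tp.
- M/tp must be in the range [1, 65535].
- E/ep must be in the range [1, 32].
- ep and tp only support the values 2, 4, 8, 16, and 32.
- group_ep and group_tp must have different names.
- C must be greater than 0, with an upper bound set by the device memory available to the operator; when shard_type is 1, C must be an integer multiple of tp.
- Communication across supernodes is not supported; only communication within a supernode is supported.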
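These constraints are easy to check on the host before launching a job. The function below is a sketch that collects every violated rule instead of stopping at the first one. `check_constraints` is a hypothetical name; the memory bound on C and the supernode restriction cannot be checked this way and are left out, and M being a multiple of tp is assumed from the M/tp dimension in the shape tables.

```python
from typing import List

SUPPORTED_WORLD_SIZES = (2, 4, 8, 16, 32)


def check_constraints(E: int, C: int, H: int, M: int, ep: int, tp: int,
                      shard_type: int, group_ep: str, group_tp: str) -> List[str]:
    """Return a list of violated constraints (empty when all checks pass)."""
    errors = []
    if ep not in SUPPORTED_WORLD_SIZES or tp not in SUPPORTED_WORLD_SIZES:
        # Later checks divide by ep and tp, so stop here.
        return ["ep and tp must each be one of 2, 4, 8, 16, 32"]
    if group_ep == group_tp:
        errors.append("group_ep and group_tp must have different names")
    if shard_type not in (0, 1):
        errors.append("shard_type must be 0 or 1")
    if not 2 <= E <= 512 or E % ep != 0:
        errors.append("E must be in [2, 512] and a multiple of ep")
    elif not 1 <= E // ep <= 32:
        errors.append("E/ep must be in [1, 32]")
    if not 1 <= H <= 65535:
        errors.append("H must be in [1, 65535]")
    elif shard_type == 0 and H % tp != 0:
        errors.append("H must be a multiple of tp when shard_type is 0")
    # Assumption: M divisible by tp, implied by the M/tp dimension.
    if M % tp != 0 or not 1 <= M // tp <= 65535:
        errors.append("M must be a multiple of tp with M/tp in [1, 65535]")
    if C <= 0:
        errors.append("C must be greater than 0")
    elif shard_type == 1 and C % tp != 0:
        errors.append("C must be a multiple of tp when shard_type is 1")
    return errors


print(check_constraints(4, 1024, 1024, 8192, 4, 2, 1, "ep_group", "tp_group"))  # []
print(check_constraints(6, 1024, 1024, 8192, 4, 2, 1, "g", "g"))
# ['group_ep and group_tp must have different names',
#  'E must be in [2, 512] and a multiple of ep']
```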
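npu_alltoall_allgather_bmm usage example
Run the following command in a terminal:
python3 -m torch.distributed.launch --nproc_per_node 8 --master_addr 127.0.0.1 --master_port 29500 demo_test.py
Note: set the master_addr and master_port arguments according to your environment.
The sample code for demo_test.py is as follows: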
import os

import torch
import torch.distributed as dist
import torch_npu
from mindspeed.ops.npu_all_to_all_all_gather_bmm import npu_alltoall_allgather_bmm

# world_size must equal ep_size * tp_size and match --nproc_per_node.
world_size = 8
ep_size = 4
tp_size = 2


def setup_ep_tp(rank, tp_size, ep_size, backend_type):
    # Initialize the EP groups. Every rank creates every group,
    # then keeps the one it belongs to.
    print("device %d initialize ep group" % rank, flush=True)
    for i in range(tp_size):
        ep_ranks = [x + ep_size * i for x in range(ep_size)]
        ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
        if rank in ep_ranks:
            ep_group_tmp = ep_group
    # Initialize the TP groups in the same way.
    print("device %d initialize tp group" % rank, flush=True)
    for i in range(ep_size):
        tp_ranks = [x * ep_size + i for x in range(tp_size)]
        tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
        if rank in tp_ranks:
            tp_group_tmp = tp_group
    return ep_group_tmp, tp_group_tmp


def get_ep_tp_hcomm_info(rank, ep_size, tp_size):
    ep_group, tp_group = setup_ep_tp(rank, tp_size, ep_size, "hccl")
    if torch.__version__ > '2.0.1':
        ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        tp_hcomm_info = tp_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        ep_hcomm_info = ep_group.get_hccl_comm_name(rank)
        tp_hcomm_info = tp_group.get_hccl_comm_name(rank)
    return ep_hcomm_info, tp_hcomm_info


if __name__ == '__main__':
    dtype = torch.float16
    x_shard_type = 1
    out_y2_flag = True
    out_y3_flag = False
    act_type = "None"
    transpose_weight = False
    rank = int(os.environ["LOCAL_RANK"])
    torch_npu.npu.set_device(rank)
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
    # These are HCCL communicator names (str), as expected by group_ep and group_tp.
    ep_group, tp_group = get_ep_tp_hcomm_info(rank, ep_size, tp_size)
    print(f'current device: {torch_npu.npu.current_device()}, local rank = {rank}, hcomm_info = {ep_group}, {tp_group}')
    E, C, H, M = 4, 1024, 1024, 8192
    # Integer division keeps every dimension an int.
    if x_shard_type == 0:
        x_shape = (E, C, H // tp_size)
    elif x_shard_type == 1:
        x_shape = (E, C // tp_size, H)
    else:
        x_shape = (E // ep_size, tp_size * ep_size * C, M // tp_size)
    weight_shape = (E // ep_size, H, M // tp_size)
    if transpose_weight:
        weight_shape = (E // ep_size, M // tp_size, H)
    bias_shape = (E // ep_size, 1, M // tp_size)
    x = torch.rand(x_shape)
    weight = torch.rand(weight_shape)
    bias = torch.rand(bias_shape)
    x_npu = x.npu().to(dtype)
    weight_npu = weight.npu().to(dtype)
    if transpose_weight:
        print(f'!!!!before transpose, weight_npu.size()={weight_npu.size()}')
        weight_npu = weight_npu.transpose(1, 2)
        print(f'!!!!after transpose, weight_npu.size()={weight_npu.size()}')
        print(f'!!!!after transpose, weight_npu.is_contiguous()={weight_npu.is_contiguous()}')
    bias_npu = bias.npu().to(dtype)
    # Setting bias to None exercises the no-bias case; remove this line to pass a bias.
    bias_npu = None
    y_npu = npu_alltoall_allgather_bmm(x_npu,
                                       weight_npu,
                                       ep_group,
                                       ep_size,
                                       tp_group,
                                       tp_size,
                                       bias=bias_npu,
                                       shard_type=x_shard_type,
                                       act_type=act_type,
                                       need_allgather_out=out_y2_flag,
                                       need_activation_feature=out_y3_flag)
    if rank == 0:
        # y_npu[0] is y1Out; dump each slice along dim 0 to its own file.
        for i, y in enumerate(y_npu[0]):
            y.cpu().numpy().tofile(f"./y_{i}.bin")
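The rank layout that setup_ep_tp builds can be previewed without any devices. The sketch below reuses the same index formulas as the demo in a standalone function; `ep_tp_rank_groups` is a hypothetical name introduced for illustration.

```python
from typing import List, Tuple


def ep_tp_rank_groups(ep_size: int, tp_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (ep_groups, tp_groups) using the same formulas as setup_ep_tp."""
    # Each EP group is a block of ep_size consecutive ranks.
    ep_groups = [[x + ep_size * i for x in range(ep_size)] for i in range(tp_size)]
    # Each TP group takes one rank from every EP block, at stride ep_size.
    tp_groups = [[x * ep_size + i for x in range(tp_size)] for i in range(ep_size)]
    return ep_groups, tp_groups


ep_groups, tp_groups = ep_tp_rank_groups(ep_size=4, tp_size=2)
print(ep_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(tp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With ep_size = 4 and tp_size = 2, every rank belongs to exactly one EP group and one TP group, which together cover all 8 ranks.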