torch.distributed.distributed_c10d._world.default_pg._get_backend(torch.device("npu")).get_hccl_comm_name
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品 | √ |
| Atlas A2 训练系列产品 | √ |
| Atlas 推理系列产品 | √ |
| Atlas 训练系列产品 | √ |
功能说明
从初始化完成的集合通信域中获取集合通信域名字。
函数原型
torch.distributed.distributed_c10d._world.default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rankid->int,init_comm=True) -> string
注:接口为PyTorch的ProcessGroup类,backend为NPU backend的方法。ProcessGroup可以为default_pg,也可以为torch.distributed.distributed_c10d.new_group创建的非default_pg。
[!NOTICE]
调用该接口时,需要保证当前current device被设置为正确。
参数说明
-
rankid (
int):必选参数,集合通信对应device的rankid。传入的rankid为全局的rankid,多机间device具有唯一的rankid。 -
init_comm (
int):可选参数,默认值为True。当值为True时,调用get_hccl_comm_name会在hccl还未完成初始化的情况下完成初始化,并返回group name。当值为False时,调用get_hccl_comm_name在hccl还未完成初始化时,不会进行初始化(包括申请内存资源等操作),并返回空字符串。
Note
hccl初始化会申请内存资源,造成内存升高,默认申请内存大小为Send buffer与Recv buffer各200M,共400M。buffer大小受环境变量HCCL_BUFFSIZE控制。
返回值说明
string
代表string类型的集合通信域的名字。
约束说明
- 使用该接口前确保
init_process_group已被调用,且初始化的backend为hccl。 - PyTorch 2.1.0及以后版本与PyTorch 2.1.0之前的版本对该接口调用方式不同,见调用示例。
调用示例
import torch
import torch_npu
import torch.multiprocessing as mp
import os
from torch.distributed.distributed_c10d import _get_default_group
import torch.distributed as dist
def example(rank, world_size):
torch.npu.set_device("npu:" + str(rank))
dist.init_process_group("hccl", rank=rank, world_size=world_size)
default_pg = _get_default_group()
if torch.__version__ > '2.0':
hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcomm_info = default_pg.get_hccl_comm_name(rank)
print(hcomm_info)
def main():
world_size = 2
mp.spawn(example,
args=(world_size, ),
nprocs=world_size,
join=True)
if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29505"
main()
group_name_0
group_name_0