import os
import numpy as np
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.multiprocessing as mp
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
import torch_npu
class OptionsTest(TestCase):
    """Multi-process tests for ``ProcessGroupHCCL.Options.hccl_config``.

    Every public test spawns one worker process per rank. Each worker
    builds an ``Options`` object, initialises the default HCCL process group
    with it and runs ``all_reduce``. The parent only checks that every worker
    exited with code 0; the real assertions run inside the workers.
    """

    @classmethod
    def _init_dist_hccl(cls, rank, options, world_size):
        """Initialise the default HCCL process group for one worker.

        Args:
            rank: Rank of the calling process; also used as the NPU device
                index passed to ``torch_npu.npu.set_device``.
            options: ``ProcessGroupHCCL.Options`` forwarded as ``pg_options``.
            world_size: Total number of participating processes.
        """
        # Rendezvous settings are read from the environment by
        # ``init_process_group``; all workers use the same local address/port.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', pg_options=options, world_size=world_size, rank=rank)

    @classmethod
    def _test_all_reduce_with_options(cls, rank, ranks, world_size, input1):
        """Worker: check that a group keeps the options it was created with.

        Args:
            rank: Rank of this worker.
            ranks: Ranks used to build the additional sub-group.
            world_size: Total number of workers.
            input1: CPU tensor that is moved to the NPU and all-reduced.
        """
        options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
        hccl_config1 = {"hccl_buffer_size": 300, "group_name": "custom"}
        options.hccl_config = hccl_config1
        cls._init_dist_hccl(rank, options, world_size)
        input1 = input1.npu()
        dist.all_reduce(input1)

        # Mutate the same Options object after the group exists; the default
        # group is expected to still report the first configuration.
        hccl_config2 = {"hccl_buffer_size": 200}
        options.hccl_config = hccl_config2
        dist.all_reduce(input1)
        default_pg = c10d._get_default_group()._get_backend(torch.device('npu'))
        test_case = TestCase()
        test_case.assertEqual(default_pg.options.hccl_config, hccl_config1,
                              "Once Options are set for a ProcessGroupHCCL, later changes to Options won't affect "
                              "that ProcessGroupHCCL.")
        test_case.assertEqual(default_pg.options.hccl_config.get("group_name", ""), "custom")

        # A new group created now receives the updated options.
        pg = dist.new_group(backend='hccl', ranks=ranks, pg_options=options)
        dist.all_reduce(input1, group=pg)

    @classmethod
    def _test_options_wrong_type(cls, rank, hccl_config, error_expect, world_size, input1):
        """Worker: expect ``RuntimeError`` for an ill-typed ``hccl_config``.

        Args:
            rank: Rank of this worker.
            hccl_config: Config dict containing a value of the wrong type.
            error_expect: Regular expression the error message must match.
            world_size: Total number of workers.
            input1: CPU tensor that is moved to the NPU and all-reduced.
        """
        options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
        options.hccl_config = hccl_config
        input1 = input1.npu()
        test_case = TestCase()
        # Both the group initialisation and the first collective are inside
        # the context, so the error is accepted from either call.
        with test_case.assertRaisesRegex(RuntimeError, error_expect):
            cls._init_dist_hccl(rank, options, world_size)
            dist.all_reduce(input1)

    @classmethod
    def _test_options_group_name_wrong_types(cls, rank, ranks, world_size, input1):
        """Worker: ``group_name`` given as an int must be rejected."""
        cls._test_options_wrong_type(rank, {"group_name": 123}, "Value type of group_name should be string", world_size, input1)

    @classmethod
    def _test_options_qos_traffic_class_wrong_types(cls, rank, ranks, world_size, input1):
        """Worker: ``qos_traffic_class`` given as a str must be rejected."""
        cls._test_options_wrong_type(rank, {"qos_traffic_class": "123"}, "Value type of qos_traffic_class should be int.", world_size, input1)

    @classmethod
    def _test_options_qos_service_level_wrong_types(cls, rank, ranks, world_size, input1):
        """Worker: ``qos_service_level`` given as a str must be rejected."""
        cls._test_options_wrong_type(rank, {"qos_service_level": "123"}, "Value type of qos_service_level should be int.", world_size, input1)

    @classmethod
    def _test_options_hccl_op_expansion_mode_wrong_types(cls, rank, ranks, world_size, input1):
        """Worker: ``hccl_op_expansion_mode`` given as a str must be rejected."""
        cls._test_options_wrong_type(rank, {"hccl_op_expansion_mode": "123"}, "Value type of hccl_op_expansion_mode should be int.", world_size, input1)

    def _test_multiprocess(self, f, input1, world_size):
        """Run ``f`` once per rank in spawned processes and check exit codes.

        Args:
            f: Worker callable invoked as ``f(rank, ranks, world_size, input)``.
            input1: Tensor; a CPU copy is handed to every worker.
            world_size: Number of worker processes to start.
        """
        ctx = mp.get_context('spawn')
        ps = []
        ranks = range(world_size)
        for rank in ranks:
            p = ctx.Process(
                target=f,
                args=(rank, ranks, world_size, input1.cpu()))
            p.start()
            ps.append(p)

        for p in ps:
            p.join()

        # Report the failing rank so a non-zero exit can be traced back to
        # the worker that produced it.
        for rank, p in zip(ranks, ps):
            self.assertEqual(p.exitcode, 0,
                             f"worker for rank {rank} exited with code {p.exitcode}")

    def _run_for_world_sizes(self, worker, world_sizes=(2,)):
        """Run ``worker`` through ``_test_multiprocess`` for each world size.

        Args:
            worker: Worker classmethod to execute in every spawned process.
            world_sizes: World sizes to exercise; defaults to a single run
                with two processes.
        """
        # Tensor spec consumed by create_common_tensor; presumably
        # [dtype, npu format, shape] - confirm against that helper.
        shape = [np.int32, 0, [2, 3, 16]]
        for world_size in world_sizes:
            _, input1 = create_common_tensor(shape, -10, 10)
            self._test_multiprocess(worker, input1, world_size)

    @skipIfUnsupportMultiNPU(2)
    def test_all_reduce_with_options(self):
        """all_reduce works and the default group keeps its original options."""
        self._run_for_world_sizes(OptionsTest._test_all_reduce_with_options)

    @skipIfUnsupportMultiNPU(2)
    def test_options_group_name_wrong_type(self):
        """A non-string ``group_name`` raises ``RuntimeError``."""
        self._run_for_world_sizes(OptionsTest._test_options_group_name_wrong_types)

    @skipIfUnsupportMultiNPU(2)
    def test_options_qos_traffic_class_wrong_type(self):
        """A non-int ``qos_traffic_class`` raises ``RuntimeError``."""
        self._run_for_world_sizes(OptionsTest._test_options_qos_traffic_class_wrong_types)

    @skipIfUnsupportMultiNPU(2)
    def test_options_qos_service_level_wrong_type(self):
        """A non-int ``qos_service_level`` raises ``RuntimeError``."""
        self._run_for_world_sizes(OptionsTest._test_options_qos_service_level_wrong_types)

    @skipIfUnsupportMultiNPU(2)
    def test_options_hccl_op_expansion_mode_wrong_type(self):
        """A non-int ``hccl_op_expansion_mode`` raises ``RuntimeError``."""
        self._run_for_world_sizes(OptionsTest._test_options_hccl_op_expansion_mode_wrong_types)


# Script entry point: hand control to the torch_npu test runner when this
# file is executed directly rather than imported.
if __name__ == '__main__':
    run_tests()