import os
import unittest
import shutil
from pathlib import Path
import torch
import torch.distributed as dist
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
import torch_npu
import torchair
# Seed the default torch generator with a fixed value so that the random
# inputs built by the tests below are reproducible across runs.
torch.manual_seed(7)
# Seed the NPU generators as well. NOTE(review): the ``torch.npu`` namespace is
# presumably registered by the ``import torch_npu`` above, so this line must
# stay after that import -- confirm before reordering.
torch.npu.manual_seed_all(7)
class HcomStaticKernelTest(unittest.TestCase):
    """Multi-device tests for aclgraph static-shape kernels around an HCCL all_reduce.

    Every test spawns one worker process per device. Each worker joins an HCCL
    process group, compiles a small model (ReLU / softmax / sqrt / square /
    all_reduce / add) with ``config.mode = "reduce-overhead"`` and the
    ``_aclnn_static_shape_kernel`` option enabled, and runs it once. The parent
    process then inspects the directory layout the workers left on disk.

    Tests are skipped when fewer than two devices are detected.
    """

    # ------------------------------------------------------------------ #
    # Worker-side helpers (executed inside the spawned processes)
    # ------------------------------------------------------------------ #

    @classmethod
    def _init_dist_hccl(cls, rank, world_size):
        """Join the HCCL process group as ``rank`` and return ``torch.distributed``.

        Args:
            rank: Rank of the calling process; also used as the device index.
            world_size: Total number of processes in the group.

        Returns:
            The ``torch.distributed`` module (imported here as ``dist``).
        """
        torchair.patch_for_hcom()
        # Rendezvous over the loopback interface on a fixed port.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
        return dist

    @classmethod
    def _test_static_kernel_with_cache(cls, rank, world_size, input, cache_dir):
        """Worker: run the model through ``torchair.inference.cache_compile``.

        Intended as a ``torch.multiprocessing.spawn`` target; ``spawn`` passes
        the process index as the first positional argument (``rank``).

        Args:
            rank: Process index / device index of this worker.
            world_size: Number of workers taking part in the all_reduce.
            input: CPU tensor that is moved to the device and fed to the model.
            cache_dir: Directory handed to ``cache_compile`` as ``cache_dir``.
        """

        # NOTE: the class name and the ``prompt`` method name are kept as-is;
        # the parent test looks for a cache directory whose name starts with
        # "CacheHcomModel".
        class CacheHcomModel(torch.nn.Module):
            def __init__(self, config, cache_dir):
                super().__init__()
                self.relu = torch.nn.ReLU()
                self.cached_prompt = torchair.inference.cache_compile(
                    self.prompt, config=config, cache_dir=cache_dir)

            def prompt(self, x):
                return self._forward(x)

            def forward(self, x):
                return self.cached_prompt(x)

            def _forward(self, x):
                relu_01 = self.relu(x)
                reshape_01 = torch.reshape(relu_01, (1, 32, 1, 128))
                # dim=1 is what PyTorch's implicit-dim fallback selects for a
                # 4-D input; spelling it out avoids the deprecation warning.
                softmax_01 = torch.nn.functional.softmax(reshape_01, dim=1)
                sqrt_01 = torch.sqrt(softmax_01)
                relu_02 = self.relu(sqrt_01)
                square_01 = torch.square(relu_02)
                torch_npu.distributed.distributed_c10d.dist.all_reduce(square_01)
                add_01 = torch.add(square_01, square_01)
                return add_01

        torch.npu.set_device(rank)
        cls._init_dist_hccl(rank, world_size)

        config = torchair.CompilerConfig()
        config.mode = "reduce-overhead"
        config.experimental_config.aclgraph._aclnn_static_shape_kernel = True

        npu_model = CacheHcomModel(config, cache_dir).npu()
        npu_input = input.npu()
        # Only the on-disk side effects are checked by the parent process, so
        # the result tensor is intentionally not kept.
        npu_model(npu_input)

    @classmethod
    def _test_static_kernel_without_cache(cls, rank, world_size, input, kernel_build_dir):
        """Worker: run the model through ``torch.compile`` with the NPU backend.

        Intended as a ``torch.multiprocessing.spawn`` target; ``spawn`` passes
        the process index as the first positional argument (``rank``).

        Args:
            rank: Process index / device index of this worker.
            world_size: Number of workers taking part in the all_reduce.
            input: CPU tensor that is moved to the device and fed to the model.
            kernel_build_dir: Value assigned to the
                ``_aclnn_static_shape_kernel_build_dir`` compiler option.
        """

        class HcomModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                relu_01 = self.relu(x)
                reshape_01 = torch.reshape(relu_01, (1, 32, 1, 128))
                # dim=1 is what PyTorch's implicit-dim fallback selects for a
                # 4-D input; spelling it out avoids the deprecation warning.
                softmax_01 = torch.nn.functional.softmax(reshape_01, dim=1)
                sqrt_01 = torch.sqrt(softmax_01)
                relu_02 = self.relu(sqrt_01)
                square_01 = torch.square(relu_02)
                torch_npu.distributed.distributed_c10d.dist.all_reduce(square_01)
                add_01 = torch.add(square_01, square_01)
                return add_01

        torch.npu.set_device(rank)
        cls._init_dist_hccl(rank, world_size)

        config = torchair.CompilerConfig()
        config.mode = "reduce-overhead"
        config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
        config.experimental_config.aclgraph._aclnn_static_shape_kernel_build_dir = kernel_build_dir

        npu_input = input.npu()
        npu_model = HcomModel().npu()
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        npu_model = torch.compile(npu_model, fullgraph=True, backend=npu_backend, dynamic=False)
        # Only the on-disk side effects are checked by the parent process.
        npu_model(npu_input)

    # ------------------------------------------------------------------ #
    # Parent-side helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _detect_world_size():
        """Return the number of usable devices, derived from ``/dev`` entries.

        Counts the ``/dev`` entries whose name contains "davinci" and subtracts
        one. NOTE(review): the subtraction presumably discounts a non-device
        management node -- confirm against the target platform.
        """
        with os.popen("ls /dev | grep davinci | wc -l") as pipe:
            raw_count = pipe.read()
        # Treat empty output as "no devices" instead of failing in int().
        return int(raw_count.strip() or 0) - 1

    @staticmethod
    def _reset_dir(path):
        """Remove ``path`` if it exists and recreate it as an empty directory."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    def _require_world_size(self):
        """Return the detected world size, skipping the test if it is below 2."""
        world_size = self._detect_world_size()
        if world_size < 2:
            self.skipTest("static kernel hcom tests need at least 2 davinci devices")
        return world_size

    def _spawn_workers(self, worker, world_size, target_dir):
        """Run ``worker`` in ``world_size`` processes and wait for all of them.

        Each process is called as ``worker(rank, world_size, tensor, target_dir)``
        with a fresh random float16 tensor of shape (1, 4, 8, 128).
        """
        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
        # The manager context is retained from the original test flow; nothing
        # is shared through it here.
        with torch.multiprocessing.Manager():
            sample = torch.randn(1, 4, 8, 128, dtype=torch.float16)
            torch.multiprocessing.spawn(worker,
                                        args=(world_size, sample, target_dir),
                                        nprocs=world_size, join=True)

    def _assert_kernel_build_layout(self, kernel_build_dir, world_size):
        """Check the directory layout produced by the non-cached workers.

        Expects exactly one ``aclnn_static_shape_kernel_outputs`` directory
        containing one ``ts*_outputs`` directory per rank and, anywhere below
        it, exactly one ``*.run`` package.
        """
        build_root = Path(kernel_build_dir)
        self.assertTrue(build_root.exists())
        outputs_dirs = [d for d in build_root.iterdir()
                        if d.is_dir() and d.name == "aclnn_static_shape_kernel_outputs"]
        self.assertEqual(len(outputs_dirs), 1)
        ts_outputs_dirs = [d for d in outputs_dirs[0].iterdir()
                           if d.is_dir() and d.name.startswith("ts") and d.name.endswith("_outputs")]
        self.assertEqual(len(ts_outputs_dirs), world_size)
        run_pkgs = list(outputs_dirs[0].rglob("*.run"))
        self.assertEqual(len(run_pkgs), 1)

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #

    def test_static_kernel_without_cache(self):
        """torch.compile path: one kernel package is built for all ranks."""
        world_size = self._require_world_size()
        kernel_build_dir = "./static_kernel_dir_without_cache"
        self._reset_dir(kernel_build_dir)

        self._spawn_workers(HcomStaticKernelTest._test_static_kernel_without_cache,
                            world_size, kernel_build_dir)
        self._assert_kernel_build_layout(kernel_build_dir, world_size)

    def test_static_kernel_with_cache(self):
        """cache_compile path: the package is built once and reused on rerun."""
        world_size = self._require_world_size()
        cache_dir = "./static_kernel_dir_with_cache"
        self._reset_dir(cache_dir)

        # First run: populates the cache directory.
        self._spawn_workers(HcomStaticKernelTest._test_static_kernel_with_cache,
                            world_size, cache_dir)

        cache_dir_path = Path(cache_dir)
        self.assertTrue(cache_dir_path.exists())
        cachemode_dirs = [d for d in cache_dir_path.iterdir()
                          if d.is_dir() and d.name.startswith("CacheHcomModel")]
        self.assertEqual(len(cachemode_dirs), 1)
        cache_rank_dirs = [d for d in cachemode_dirs[0].iterdir() if d.is_dir()]
        self.assertEqual(len(cache_rank_dirs), world_size)
        run_pkgs = list(cachemode_dirs[0].rglob("*.run"))
        self.assertEqual(len(run_pkgs), 1)
        run_pkg_path = run_pkgs[0].resolve()
        # The single package is expected to live under rank 0's directory.
        rank_0_dir = f'world{world_size}global_rank0'
        self.assertIn(rank_0_dir, str(run_pkg_path))

        # Second run against the same cache directory: no additional package
        # may appear, and the existing one must still be the same file.
        self._spawn_workers(HcomStaticKernelTest._test_static_kernel_with_cache,
                            world_size, cache_dir)
        self.assertTrue(cache_dir_path.exists())
        run_pkgs_02 = list(cache_dir_path.rglob("*.run"))
        # Check the freshly collected list (the original re-checked the list
        # gathered after the first run, which made this assertion vacuous).
        self.assertEqual(len(run_pkgs_02), 1)
        self.assertEqual(str(run_pkg_path), str(run_pkgs_02[0].resolve()))

    def test_static_kernel_with_gloo_fix(self):
        """Same flow and checks as the non-cached test, in a separate directory.

        NOTE(review): the name refers to a gloo fix, but no gloo-specific setup
        is visible in this test -- confirm what distinguishes this scenario.
        """
        world_size = self._require_world_size()
        kernel_build_dir = "./static_kernel_dir_with_gloo_fix"
        self._reset_dir(kernel_build_dir)

        self._spawn_workers(HcomStaticKernelTest._test_static_kernel_without_cache,
                            world_size, kernel_build_dir)
        self._assert_kernel_build_layout(kernel_build_dir, world_size)
# Script entry point: hand control to the unittest runner only when this file
# is executed directly, not when it is imported as a module.
if __name__ == '__main__':
    unittest.main()