import os
import subprocess
import sys
import textwrap
from torch_npu.testing.testcase import run_tests, TestCase
# Attribute names that the snapshot test requires on ``torch_npu._C``.
REQUIRED_C_EXTENSION_CHILDREN = [
    "_profiler", "_distributed_c10d", "_cd", "_logging", "_flops_count",
]

# Modules that must already be in ``sys.modules`` right after
# ``import torch_npu``.
EXPECTED_LOADED_MODULES = [
    # Device runtime package.
    "torch_npu.npu",
    "torch_npu.npu.amp",
    "torch_npu.npu.aclnn",
    # Top-level feature packages.
    "torch_npu.optim",
    "torch_npu.dynamo",
    "torch_npu._logging",
    "torch_npu._afd",
    "torch_npu.profiler",
    # Distributed.
    "torch_npu.distributed",
    "torch_npu.distributed.rpc",
    "torch_npu.distributed.nn",
    "torch_npu.distributed.nn.functional",
    # Operator plugin.
    "torch_npu.op_plugin",
    "torch_npu.op_plugin.meta",
    "torch_npu.op_plugin.meta._meta_registrations",
    # ASD and utilities.
    "torch_npu.asd.checksum",
    "torch_npu.utils._dynamo",
    "torch_npu.utils._inductor",
    "torch_npu.utils.custom_ops",
    "torch_npu.utils.patch_getenv",
    "torch_npu.utils.syncbatchnorm",
]

# Modules that must NOT be in ``sys.modules`` right after ``import torch_npu``.
EXPECTED_NOT_LOADED_MODULES = ["torch_npu._C._afd"]

# Attributes that must be reachable directly on the ``torch_npu`` package.
EXPECTED_TOP_LEVEL_ATTRS = [
    "npu", "optim", "dynamo", "_afd", "profiler", "op_plugin", "utils",
]

# Top-level names that are listed in ``__all__``/``dir()`` but only cached in
# ``torch_npu.__dict__`` after first attribute access.
LAZY_TOP_LEVEL_APIS = ["HiFloat8Tensor", "erase_stream", "matmul_checksum"]

# Operator names that must be present on ``torch_npu._afd``.
AFD_OPS = [
    "attention_worker_scheduler_",
    "attention_worker_scheduler",
    "ffn_worker_scheduler_",
    "ffn_worker_scheduler",
]
class TestTorchNpuBootstrap(TestCase):
    """Snapshot tests for the state left behind by ``import torch_npu``.

    Every scenario is executed with ``python -c`` in a separate interpreter
    process, so the checks observe import-time state (``sys.modules``,
    patched attributes, lazy-init flags) of a fresh process rather than the
    state of the process running the test suite.
    """

    def _run_python(self, code: str, *, optional: bool = False):
        """Run ``code`` in a child interpreter and require exit status 0.

        Args:
            code: Python source. It is passed through ``textwrap.dedent`` so
                callers can write it as an indented triple-quoted string.
            optional: When true, a non-zero exit status skips the current
                test instead of failing it.

        The failure/skip message contains the return code, the script that
        was run, and the captured stdout/stderr.
        """
        script = textwrap.dedent(code)
        proc = subprocess.run(
            [sys.executable, "-c", script],
            text=True,
            capture_output=True,
            env=os.environ.copy(),
        )
        if proc.returncode == 0:
            return
        # Tracebacks from ``python -c`` only report ``File "<string>"`` line
        # numbers, so echo the script to make the failing line identifiable
        # (and to tell apart the cases of tests that loop over scripts).
        message = (
            f"subprocess failed with return code {proc.returncode}\n"
            f"script:\n{script}\n"
            f"stdout:\n{proc.stdout}\n"
            f"stderr:\n{proc.stderr}"
        )
        if optional:
            # skipTest() is expected to raise, so fail() below is not reached.
            self.skipTest(message)
        self.fail(message)

    def test_01_import_order_compatibility(self):
        """Import torch and torch_npu in each order, alone, and twice."""
        cases = [
            "import torch_npu",
            "import torch\nimport torch_npu",
            "import torch_npu\nimport torch",
            "import torch_npu\nimport torch_npu",
        ]
        for code in cases:
            self._run_python(code)

    def test_02_import_state_snapshot(self):
        """Check registered attributes and loaded modules after import.

        Verifies ``torch.npu`` registration, the required ``torch_npu._C``
        children, the expected loaded / not-loaded module sets, the expected
        top-level attributes, and that ``torch_npu._op_plugin_docs`` is in
        ``sys.modules`` but not an attribute of ``torch_npu``.
        """
        self._run_python(
            f"""
            import sys
            import torch
            import torch_npu
            import torch_npu._C as C
            required_c_children = {REQUIRED_C_EXTENSION_CHILDREN!r}
            expected_loaded = {EXPECTED_LOADED_MODULES!r}
            expected_not_loaded = {EXPECTED_NOT_LOADED_MODULES!r}
            expected_top_attrs = {EXPECTED_TOP_LEVEL_ATTRS!r}
            assert hasattr(torch, "npu"), "torch.npu is not registered"
            assert torch.npu is torch_npu.npu
            assert hasattr(torch.Tensor, "npu"), "torch.Tensor.npu is missing"
            assert hasattr(torch.nn.Module, "npu"), "torch.nn.Module.npu is missing"
            for name in required_c_children:
                assert hasattr(C, name), f"torch_npu._C.{{name}} is missing"
            # Old behavior: AFD is exposed as torch_npu._afd, not torch_npu._C._afd.
            assert hasattr(C, "_afd") is False
            missing_modules = [
                name for name in expected_loaded if name not in sys.modules
            ]
            assert not missing_modules, (
                f"init-time modules changed, missing: {{missing_modules}}"
            )
            unexpected_modules = [
                name for name in expected_not_loaded if name in sys.modules
            ]
            assert not unexpected_modules, (
                f"unexpected eager modules loaded: {{unexpected_modules}}"
            )
            missing_attrs = [
                name for name in expected_top_attrs if not hasattr(torch_npu, name)
            ]
            assert not missing_attrs, (
                f"torch_npu top-level attrs changed, missing: {{missing_attrs}}"
            )
            import torch_npu.npu
            import torch_npu.npu.aclnn
            assert torch_npu.npu is not None
            assert torch_npu.npu.aclnn is not None
            # _op_plugin_docs is imported for side effect, then removed from top-level.
            assert "torch_npu._op_plugin_docs" in sys.modules
            assert not hasattr(torch_npu, "_op_plugin_docs")
            """
        )

    def test_03_public_exports_snapshot(self):
        """Check lazy top-level APIs, public op exports and DType re-exports.

        * Lazy names appear in ``__all__`` and ``dir()`` but enter
          ``torch_npu.__dict__`` only after the first ``getattr``.
        * Every ``public_npu_functions`` entry that exists on
          ``torch.ops.npu`` is on ``torch_npu``, in ``torch_npu.__all__`` and
          on ``torch``.
        * Every public name of ``torch_npu._C._cd.DType`` (other than
          ``name``) is on ``torch_npu`` with an equal value or equal repr.
        """
        self._run_python(
            f"""
            import torch
            import torch_npu
            import torch_npu._C as C
            from torch_npu.utils.exposed_api import public_npu_functions
            lazy_names = {LAZY_TOP_LEVEL_APIS!r}
            for name in lazy_names:
                assert name in torch_npu.__all__, f"{{name}} is missing from __all__"
                assert name in dir(torch_npu), f"{{name}} is missing from dir(torch_npu)"
                assert name not in torch_npu.__dict__, (
                    f"{{name}} should not be cached before lazy access"
                )
                value = getattr(torch_npu, name)
                assert value is not None
                assert name in torch_npu.__dict__, (
                    f"{{name}} was not cached after lazy access"
                )
            available_public_ops = []
            missing_from_torch_npu = []
            missing_from_all = []
            missing_torch_alias = []
            for name in public_npu_functions:
                if not hasattr(torch.ops.npu, name):
                    continue
                available_public_ops.append(name)
                if not hasattr(torch_npu, name):
                    missing_from_torch_npu.append(name)
                if name not in torch_npu.__all__:
                    missing_from_all.append(name)
                if not hasattr(torch, name):
                    missing_torch_alias.append(name)
            assert available_public_ops, "no available public torch.ops.npu ops found"
            assert not missing_from_torch_npu, (
                f"some public ops are missing from torch_npu: "
                f"{{missing_from_torch_npu[:20]}}"
            )
            assert not missing_from_all, (
                f"some public ops are missing from torch_npu.__all__: "
                f"{{missing_from_all[:20]}}"
            )
            assert not missing_torch_alias, (
                f"some public ops are missing deprecated torch aliases: "
                f"{{missing_torch_alias[:20]}}"
            )
            dtype_names = [
                name
                for name in dir(C._cd.DType)
                if not name.startswith("_") and name not in ["_dir", "name"]
            ]
            missing_dtype = []
            mismatch_dtype = []
            for name in dtype_names:
                if not hasattr(torch_npu, name):
                    missing_dtype.append(name)
                    continue
                exported = getattr(torch_npu, name)
                source = getattr(C._cd.DType, name)
                # Pybind objects may not preserve Python identity across getattr calls.
                if exported != source and repr(exported) != repr(source):
                    mismatch_dtype.append((name, repr(exported), repr(source)))
            assert not missing_dtype, (
                f"some DType symbols are missing from torch_npu: {{missing_dtype}}"
            )
            assert not mismatch_dtype, (
                f"some DType symbols do not match torch_npu._C._cd.DType: "
                f"{{mismatch_dtype[:10]}}"
            )
            """
        )

    def test_04_framework_registration_snapshot(self):
        """Check dynamo/inductor/distributed/RPC registrations for ``npu``.

        Runs with ``optional=True``: any failure in the child script
        (including an import error) is reported as a skip, not a failure.
        """
        self._run_python(
            """
            import torch_npu
            import torch.distributed as dist
            import torch.distributed.rpc as rpc
            import torch.distributed.tensor # noqa: F401
            from torch._dynamo.device_interface import get_interface_for_device
            from torch._dynamo.backends.registry import _BACKENDS
            from torch._inductor.codegen.common import device_op_overrides_dict
            iface = get_interface_for_device("npu")
            assert iface is not None
            assert "npu" in _BACKENDS, "npu dynamo backend is not registered"
            assert "npugraph_ex" in _BACKENDS, (
                "npugraph_ex dynamo backend is not registered"
            )
            assert "npu" in device_op_overrides_dict
            assert device_op_overrides_dict.get("npu") is not None
            assert "hccl" in dist.Backend.backend_list
            assert "lccl" in dist.Backend.backend_list
            assert hasattr(rpc, "BackendType")
            names = [
                name for name in dir(rpc.BackendType)
                if "NPU" in name or "TENSORPIPE" in name
            ]
            assert "NPU_TENSORPIPE" in names
            """,
            optional=True,
        )

    def test_05_runtime_lazy_init_semantics(self):
        """Check when ``torch_npu.npu.is_initialized()`` flips to True.

        Three separate interpreters are used:

        1. Import plus ``is_available()`` / ``device_count()`` must leave the
           runtime uninitialized.
        2. ``get_device_properties(0)`` must initialize it; the script exits
           successfully early when ``device_count() <= 0``.
        3. An explicit ``init()`` must initialize it.

        NOTE(review): the third script has no ``device_count()`` guard,
        unlike the second -- confirm that failing on a device-less host is
        intended.
        """
        self._run_python(
            """
            import torch
            import torch_npu
            assert torch_npu.npu.is_initialized() is False, (
                "import torch_npu unexpectedly triggered NPU lazy init"
            )
            torch.npu.is_available()
            torch.npu.device_count()
            assert torch_npu.npu.is_initialized() is False, (
                "availability query unexpectedly triggered NPU lazy init"
            )
            """
        )
        self._run_python(
            """
            import sys
            import torch_npu
            assert torch_npu.npu.is_initialized() is False
            if torch_npu.npu.device_count() <= 0:
                sys.exit(0)
            torch_npu.npu.get_device_properties(0)
            assert torch_npu.npu.is_initialized() is True, (
                "runtime NPU API did not trigger lazy init"
            )
            """
        )
        self._run_python(
            """
            import torch_npu
            assert torch_npu.npu.is_initialized() is False
            torch_npu.npu.init()
            assert torch_npu.npu.is_initialized() is True
            """
        )

    def test_06_component_behavior_snapshot(self):
        """Check the getenv patch, ASD compatibility APIs and AFD exposure.

        * ``os.getenv`` / ``os.environ.get`` are the patched callables from
          ``torch_npu.utils.patch_getenv``.
        * Both ASD detector modules are loaded and the ASD APIs on
          ``torch_npu.utils`` are the same objects as on ``asd_detector``.
        * AFD lives at ``torch_npu._afd``; ``torch_npu._C._afd`` is neither
          an attribute nor importable, and all ``AFD_OPS`` exist.
        """
        self._run_python(
            f"""
            import os
            import sys
            import torch_npu
            import torch_npu._C as C
            import torch_npu._afd
            import torch_npu.utils as utils
            import torch_npu.utils.asd_detector as asd_detector
            import torch_npu.utils.patch_getenv as patch_getenv
            afd_ops = {AFD_OPS!r}
            # patch_getenv behavior.
            assert os.getenv is patch_getenv._patched_getenv
            assert os.environ.get is patch_getenv._patched_environ_get
            # ASD compatibility APIs.
            for module_name in [
                "torch_npu.utils._asd_detector",
                "torch_npu.utils.asd_detector",
            ]:
                assert module_name in sys.modules, (
                    f"{{module_name}} is not loaded after import torch_npu"
                )
            for api_name in ["set_asd_loss_scale", "register_asd_hook"]:
                assert hasattr(utils, api_name), (
                    f"torch_npu.utils.{{api_name}} is missing"
                )
                assert hasattr(asd_detector, api_name), (
                    f"torch_npu.utils.asd_detector.{{api_name}} is missing"
                )
                utils_api = getattr(utils, api_name)
                detector_api = getattr(asd_detector, api_name)
                assert callable(utils_api)
                assert callable(detector_api)
                assert utils_api is detector_api
            # AFD compatibility behavior.
            assert hasattr(C, "_afd") is False
            assert "torch_npu._afd" in sys.modules
            assert "torch_npu._C._afd" not in sys.modules
            try:
                import torch_npu._C._afd # noqa: F401
                raise AssertionError("import torch_npu._C._afd should fail")
            except ModuleNotFoundError:
                pass
            for name in afd_ops:
                assert hasattr(torch_npu._afd, name), (
                    f"torch_npu._afd.{{name}} is missing"
                )
            """
        )

    def test_07_distributed_patch_behavior(self):
        """Check that torch.distributed entry points are the torch_npu ones.

        Each assertion compares a ``torch`` attribute by identity with the
        corresponding ``torch_npu`` object; ``rendezvous`` and
        ``_get_addr_and_port`` are only checked for being callable.
        """
        self._run_python(
            """
            import sys
            import torch
            import torch_npu
            import torch.distributed as dist
            import torch.distributed.distributed_c10d as c10d
            import torch.distributed.launcher.api as launcher_api
            from torch.distributed.fsdp import sharded_grad_scaler
            from torch_npu.npu.amp.sharded_grad_scaler import _ShardedGradScaler
            assert torch._C._distributed_c10d._verify_params_across_processes is (
                torch_npu.distributed._verify_params_across_processes
            )
            assert torch._C._distributed_c10d.ProcessGroup._get_sequence_number_for_group is (
                torch_npu.distributed.distributed_c10d._hccl_get_sequence_number_for_group
            )
            assert c10d._add_ephemeral_timeout_for_all_pgs is (
                torch_npu.distributed.distributed_c10d._hccl_add_ephemeral_timeout_for_all_pgs
            )
            assert dist.batch_isend_irecv is (
                torch_npu.distributed.distributed_c10d._batch_isend_irecv
            )
            assert c10d.batch_isend_irecv is (
                torch_npu.distributed.distributed_c10d._batch_isend_irecv
            )
            assert dist.gather is torch_npu.distributed.distributed_c10d._gather
            assert c10d.gather is torch_npu.distributed.distributed_c10d._gather
            assert dist.gather_object is torch_npu.distributed.distributed_c10d._gather_object
            assert c10d.gather_object is torch_npu.distributed.distributed_c10d._gather_object
            assert dist.is_hccl_available is torch_npu.distributed.is_hccl_available
            assert dist.reinit_process_group is torch_npu.distributed.reinit_process_group
            assert callable(c10d.rendezvous)
            assert callable(launcher_api._get_addr_and_port)
            assert sharded_grad_scaler.ShardedGradScaler is _ShardedGradScaler
            """
        )

    def test_08_top_level_unsupported_dtype_compatibility(self):
        """Check ``torch_npu.unsupported_dtype`` is an eager top-level list.

        The value must equal the expected quantized dtype list, be present in
        both ``dir(torch_npu)`` and ``torch_npu.__dict__``, and be usable for
        filtering ``torch`` dtypes.
        """
        self._run_python(
            """
            import torch
            import torch_npu
            # Regression test for external packages such as MindSpeed.
            # They access torch_npu.unsupported_dtype directly after importing torch_npu.
            expected_unsupported_dtype = [
                torch.quint8,
                torch.quint4x2,
                torch.quint2x4,
                torch.qint32,
                torch.qint8,
            ]
            unsupported_dtype = torch_npu.unsupported_dtype
            assert unsupported_dtype == expected_unsupported_dtype
            assert "unsupported_dtype" in dir(torch_npu)
            assert "unsupported_dtype" in torch_npu.__dict__
            # Simulate MindSpeed-style dtype filtering.
            valid_dtype_names = []
            for name, attr in torch.__dict__.items():
                if isinstance(attr, torch.dtype) and attr not in torch_npu.unsupported_dtype:
                    valid_dtype_names.append(name)
            assert valid_dtype_names, "no valid torch dtype found"
            """
        )

    def test_09_legacy_submodule_attribute_compatibility(self):
        """Check child modules are bound as attributes on their parents.

        Covers ``torch_npu.asd.checksum`` (including ``_matmul_checksum``)
        and ``torch_npu.utils.syncbatchnorm``.
        """
        self._run_python(
            """
            import sys
            import torch_npu
            # Old behavior: importing torch_npu also exposed these child modules
            # as attributes on their parent packages.
            assert "torch_npu.asd.checksum" in sys.modules
            assert hasattr(torch_npu, "asd")
            assert hasattr(torch_npu.asd, "checksum")
            assert torch_npu.asd.checksum is sys.modules["torch_npu.asd.checksum"]
            assert hasattr(torch_npu.asd.checksum, "_matmul_checksum")
            assert "torch_npu.utils.syncbatchnorm" in sys.modules
            assert hasattr(torch_npu.utils, "syncbatchnorm")
            assert torch_npu.utils.syncbatchnorm is (
                sys.modules["torch_npu.utils.syncbatchnorm"]
            )
            """
        )

    def test_10_legacy_top_level_distributed_api_compatibility(self):
        """Check ``ParallelStore`` / ``_ShardedGradScaler`` on ``torch_npu``.

        Both names must be top-level attributes identical to their defining
        objects, must not be in ``torch_npu.__all__``, and the FSDP
        ``ShardedGradScaler`` must be ``_ShardedGradScaler``.
        """
        self._run_python(
            """
            import torch_npu
            from torch.distributed.fsdp import sharded_grad_scaler
            from torch_npu._C._distributed_c10d import ParallelStore
            from torch_npu.npu.amp.sharded_grad_scaler import _ShardedGradScaler
            # Old behavior: these names were visible on torch_npu top-level
            # due to module-scope imports in the old monolithic __init__.py.
            assert hasattr(torch_npu, "ParallelStore")
            assert torch_npu.ParallelStore is ParallelStore
            assert "ParallelStore" not in torch_npu.__all__
            assert hasattr(torch_npu, "_ShardedGradScaler")
            assert torch_npu._ShardedGradScaler is _ShardedGradScaler
            assert "_ShardedGradScaler" not in torch_npu.__all__
            # The FSDP patch behavior should still be preserved.
            assert sharded_grad_scaler.ShardedGradScaler is _ShardedGradScaler
            """
        )

    def test_11_import_does_not_trigger_device_count(self):
        """Check that importing torch_npu does not probe the device count.

        After the import, the child sets ``ASCEND_RT_VISIBLE_DEVICES`` to
        ``"32"`` and expects ``torch_npu._C._npu_getDeviceCount()`` to return
        0, with ``is_initialized()`` still False before and after.
        """
        cases = [
            "import torch; import torch_npu",
            "import torch_npu",
        ]
        for import_code in cases:
            self._run_python(
                f"""
                import os
                {import_code}
                assert torch_npu.npu.is_initialized() is False, (
                    "import torch_npu unexpectedly triggered NPU lazy init"
                )
                # Regression test for import-time low-level device probing.
                # If import torch_npu has already called low-level NPU device count,
                # changing ASCEND_RT_VISIBLE_DEVICES here will no longer take effect.
                os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "32"
                raw_count = torch_npu._C._npu_getDeviceCount()
                assert raw_count == 0, (
                    f"import torch_npu unexpectedly triggered low-level NPU "
                    f"device probing, raw_count={{raw_count}}"
                )
                assert torch_npu.npu.is_initialized() is False
                """
            )
# Script entry point: hand control to the test runner imported at the top of
# the file when this module is executed directly.
if __name__ == "__main__":
    run_tests()