import torch
from torch_npu.npu import device_count
from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device
from torch_npu.utils._inductor import NPUDeviceOpOverrides
from torch_npu._inductor.config import config as npu_config
from torch_npu._inductor.npu_device import NewNPUDeviceOpOverrides
from torch_npu.testing.testcase import TestCase, run_tests
class TestNpuDevice(TestCase):
    """Unit tests for ``NewNPUDeviceOpOverrides`` from ``torch_npu._inductor.npu_device``.

    Each test checks the code/identifier string that one override method
    returns.
    """

    # Headers that abi_compatible_header() must emit.
    _EXPECTED_HEADERS = (
        "#include <fstream>",
        "#include <vector>",
        "#include <iostream>",
        "#include <string>",
        "#include <tuple>",
        "#include <unordered_map>",
        "#include <memory>",
        "#include <filesystem>",
        "#include <assert.h>",
        "#include <stdbool.h>",
        "#include <sys/syscall.h>",
        "#include <torch_npu/csrc/framework/OpCommand.h>",
        "#include <torch_npu/csrc/core/npu/NPUStream.h>",
        "#include \"experiment/runtime/runtime/rt.h\"",
    )

    def setUp(self):
        """Create a fresh overrides instance for every test."""
        # Keep the base-class setup (torch_npu TestCase) in effect.
        super().setUp()
        self.overrides = NewNPUDeviceOpOverrides()

    def test_aoti_get_stream(self):
        """aoti_get_stream() names the AOTI current-stream getter."""
        expected = "aoti_torch_get_current_cuda_stream"
        self.assertEqual(self.overrides.aoti_get_stream(), expected)

    def test_cpp_stream_type(self):
        """cpp_stream_type() returns the ACL stream type name."""
        expected = "aclrtStream"
        self.assertEqual(self.overrides.cpp_stream_type(), expected)

    def test_abi_compatible_header(self):
        """abi_compatible_header() contains every required #include line."""
        result = self.overrides.abi_compatible_header()
        # subTest reports every missing header instead of stopping at the first.
        for header in self._EXPECTED_HEADERS:
            with self.subTest(header=header):
                self.assertIn(header, result)

    def test_cpp_aoti_stream_guard(self):
        """cpp_aoti_stream_guard() names the AOTI stream-guard class."""
        expected = "AOTICudaStreamGuard"
        self.assertEqual(self.overrides.cpp_aoti_stream_guard(), expected)

    def test_cpp_aoti_device_guard_not_implemented(self):
        """cpp_aoti_device_guard() is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            self.overrides.cpp_aoti_device_guard()

    def test_device_guard(self):
        """device_guard(idx) emits a torch.npu device context expression."""
        expected = "torch.npu.utils.device(0)"
        self.assertEqual(self.overrides.device_guard(0), expected)

    def test_synchronize(self):
        """synchronize() emits the two statements that sync the current NPU stream."""
        result = self.overrides.synchronize()
        # Compare the non-blank lines with surrounding whitespace stripped, so
        # the test does not depend on how the triple-quoted snippet happens to
        # be indented in the implementation's source.
        # NOTE(review): if the exact indentation of the emitted snippet is part
        # of the contract, tighten this back to an exact string match.
        statements = [line.strip() for line in result.splitlines() if line.strip()]
        expected = [
            "stream = torch.npu.current_stream()",
            "stream.synchronize()",
        ]
        self.assertEqual(statements, expected)

    def test_set_device(self):
        """set_device(idx) emits a torch.npu.set_device call."""
        expected = "torch.npu.set_device(0)"
        self.assertEqual(self.overrides.set_device(0), expected)

    def test_import_get_raw_stream_as(self):
        """import_get_raw_stream_as(name) aliases get_current_raw_stream to *name*."""
        expected = "from torch_npu._inductor import get_current_raw_stream as test_name"
        self.assertEqual(self.overrides.import_get_raw_stream_as("test_name"), expected)
# Running this file directly hands control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()