import torch
from torch._dynamo.device_interface import DeviceInterface, register_interface_for_device, \
caching_worker_current_devices, caching_worker_device_properties
from torch_npu._C import _npu_getCurrentRawStream as get_npu_stream
from ..npu.streams import Event, Stream
from ..npu import current_device, set_device, device_count, stream, current_stream, \
set_stream, synchronize, get_device_capability
from ..npu import get_device_properties as get_device_properties_npu
class NpuInterface(DeviceInterface):
    """torch._dynamo ``DeviceInterface`` implementation for ``npu`` devices.

    Most members delegate directly to the runtime helpers imported from
    ``..npu`` at the top of this module; ``Worker`` provides cache-backed
    variants of the device queries.
    """

    # NOTE(review): the CUDA interface binds a device-switching context manager
    # here (``torch.cuda.device``). ``torch.device`` used as a context manager
    # changes the default device for tensor factories instead of the current
    # NPU -- confirm this is intended.
    device = torch.device
    Event = Event
    Stream = Stream

    class Worker:
        """Device queries backed by the module-level worker caches.

        ``set_device`` only records the index in
        ``caching_worker_current_devices``; ``get_device_properties`` snapshots
        the properties of every device once into
        ``caching_worker_device_properties`` and indexes that snapshot
        afterwards.
        """

        @staticmethod
        def set_device(device: int):
            """Record ``device`` as the current NPU in the cache (no runtime call)."""
            caching_worker_current_devices["npu"] = device

        @staticmethod
        def current_device() -> int:
            """Return the cached current NPU index, else query the runtime."""
            if "npu" in caching_worker_current_devices:
                return caching_worker_current_devices["npu"]
            return current_device()

        @staticmethod
        def get_device_properties(device=None):
            """Return the cached properties object of one NPU.

            Args:
                device: ``None`` (current device), an integer index, a device
                    string such as ``"npu:1"``, or a ``torch.device``. A
                    string/``torch.device`` without an index also resolves to
                    the current device.

            Raises:
                AssertionError: if a string or ``torch.device`` names a device
                    whose type is not ``npu``.
            """
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                if isinstance(device, torch.device):
                    # Validate torch.device inputs as well as strings, so e.g.
                    # torch.device("cuda", 0) is not silently read as NPU 0.
                    if device.type != "npu":
                        raise AssertionError('device.type should be equal to npu.')
                    device = device.index
            if device is None:
                device = NpuInterface.Worker.current_device()
            if "npu" not in caching_worker_device_properties:
                # Snapshot every device's properties once; later calls only index it.
                caching_worker_device_properties["npu"] = [
                    get_device_properties_npu(i) for i in range(device_count())
                ]
            return caching_worker_device_properties["npu"][device]

    # Runtime entry points delegated directly to the ``..npu`` helpers.
    current_device = staticmethod(current_device)
    set_device = staticmethod(set_device)
    device_count = staticmethod(device_count)
    stream = staticmethod(stream)
    current_stream = staticmethod(current_stream)
    set_stream = staticmethod(set_stream)
    synchronize = staticmethod(synchronize)
    get_device_properties = staticmethod(get_device_properties_npu)
    # NOTE(review): no NPU implementation is wired up -- the attribute reads as
    # None and calling it raises TypeError.
    _set_stream_by_id = staticmethod(None)
    get_raw_stream = staticmethod(get_npu_stream)

    @staticmethod
    def is_available() -> bool:
        """Return True when the runtime reports at least one NPU."""
        return device_count() > 0

    @staticmethod
    def get_compute_capability(device=None):
        r"""Different from cuda, only return the chip model here.

        Returns ``torch.npu.get_device_name(device)`` (a name string) rather
        than a numeric major/minor capability.
        """
        return torch.npu.get_device_name(device)

    @staticmethod
    def exchange_device(device: int) -> int:
        """Make ``device`` the current NPU and return the previously current index.

        A negative index means "no device": nothing is changed and -1 is
        returned, mirroring ``torch.cuda._exchange_device``.
        """
        if device < 0:
            return -1
        curr_device = current_device()
        set_device(device)
        return curr_device

    @staticmethod
    def maybe_exchange_device(device: int) -> int:
        """Switch back to ``device`` and return the index that was current.

        Counterpart of ``exchange_device``: given an index previously returned
        by it, this restores that device (previously a no-op, which left the
        guarded device active). The runtime is only touched when the device
        actually differs. A negative index is ignored and -1 returned,
        mirroring ``torch.cuda._maybe_exchange_device``.
        """
        if device < 0:
            return -1
        curr_device = current_device()
        if curr_device != device:
            set_device(device)
        return curr_device

    @staticmethod
    def is_bf16_supported(including_emulation: bool = False):
        """Always report bf16 support.

        ``including_emulation`` is accepted for interface compatibility and
        ignored.
        """
        return True
def _dynamo_register_interface_for_device(max_devices: int = 32):
    """Register ``NpuInterface`` with torch._dynamo for ``npu`` device strings.

    Registers the bare ``"npu"`` key plus ``"npu:0"`` .. ``"npu:<max_devices - 1>"``.

    Args:
        max_devices: number of indexed device strings to register. Defaults to
            32, the previously hard-coded bound. The bound is fixed rather than
            taken from ``device_count()`` -- presumably so registration does not
            need to query the runtime; TODO confirm.
    """
    register_interface_for_device("npu", NpuInterface)
    for index in range(max_devices):
        register_interface_for_device(f"npu:{index}", NpuInterface)