import functools
from torch._inductor.runtime.hints import DeviceProperties
def patch_create_device_properties() -> None:
    """Replace ``DeviceProperties.create`` with the NPU-aware factory.

    The attribute installed on ``DeviceProperties`` is
    ``NPUDeviceProperties.create`` looked up on the class, i.e. the
    classmethod already bound to ``NPUDeviceProperties``. The change is a
    process-wide monkeypatch of the ``DeviceProperties`` class object.
    """
    setattr(DeviceProperties, "create", NPUDeviceProperties.create)
class NPUDeviceProperties(DeviceProperties):
@classmethod
@functools.lru_cache(None)
def create(cls, device) -> DeviceProperties:
import torch
from torch._dynamo.device_interface import get_interface_for_device
device_type = device.type
if torch.version.hip and device_type == "cuda":
device_type = "hip"
device_interface = get_interface_for_device(device)
props = device_interface.get_device_properties(device)
try:
multi_processor_count = props.vector_core_num
except AttributeError:
if device_type == "xpu":
multi_processor_count = props.gpu_subslice_count
else:
raise
return DeviceProperties(
type=device_type,
index=device.index,
multi_processor_count=multi_processor_count,
cc=device_interface.get_compute_capability(device),
major=getattr(props, "major", None),
regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
max_threads_per_multi_processor=getattr(
props, "max_threads_per_multi_processor", None
),
warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
)