import functools
from typing import List, Dict
from typing import Optional
from torch._inductor.remote_cache import JsonDataTy
from torch._inductor.runtime.hints import DeviceProperties
from torch.utils._triton import has_triton, has_triton_package
from .config import num_vector_core
if has_triton_package():
from triton import Config
def _load_cached_autotuning(
best_config: Dict[str, JsonDataTy],
configs_hash: str,
configs: List[Config],
inductor_meta: Dict,
) -> Optional[Config]:
if best_config is None:
return None
if best_config.pop("configs_hash", None) != configs_hash:
return None
best_config.pop("time_taken_ms", None)
num_warps = best_config.pop("num_warps", None)
num_stages = best_config.pop("num_stages", None)
triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
triton_config.found_by_coordesc = True
return triton_config
class NPUDeviceProperties(DeviceProperties):
@classmethod
@functools.lru_cache(None)
def create(cls, device) -> DeviceProperties:
import torch
from torch._dynamo.device_interface import get_interface_for_device
device_type = device.type
if torch.version.hip and device_type == "cuda":
device_type = "hip"
device_interface = get_interface_for_device(device)
props = device_interface.get_device_properties(device)
try:
multi_processor_count = num_vector_core
except AttributeError:
if device_type == "xpu":
multi_processor_count = props.gpu_subslice_count
else:
raise
return cls(
type=device_type,
index=device.index,
multi_processor_count=multi_processor_count,
cc=device_interface.get_compute_capability(device),
major=getattr(props, "major", None),
regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
max_threads_per_multi_processor=getattr(
props, "max_threads_per_multi_processor", None
),
warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
)