from enum import Enum, auto
import torch_npu
from .logs import logger
# Memoised result of get_npu_device(): None until the first detection attempt,
# then the detected NPUDevice member (including UNDEFINED after a failed query).
PLATFORM = None
class NPUDevice(Enum):
    """NPU hardware generations that get_npu_device() can report.

    ``UNDEFINED`` is the fallback used when no device is visible, the SoC
    version is not recognised, or the query fails. The remaining members name
    the chip families matched from the SoC version.
    """

    # Values come from auto() (1, 2, 3, ... in declaration order), so
    # reordering or inserting members changes the values of later members.
    UNDEFINED = auto()
    A2 = auto()
    A3 = auto()
    A5 = auto()
    Duo = auto()
def get_npu_device() -> NPUDevice:
    """Identify the NPU generation, querying torch_npu only on the first call.

    The outcome is memoised in the module-level ``PLATFORM`` variable, and
    later calls return it directly. A query that fails with ``RuntimeError``
    is logged as a warning and memoised as ``NPUDevice.UNDEFINED``.

    Returns:
        ``NPUDevice.UNDEFINED`` when ``torch_npu`` reports zero devices, when
        the SoC version matches none of the known ranges, or when the query
        raises ``RuntimeError``. Otherwise, the generation whose SoC-version
        range contains ``torch_npu.npu.get_soc_version()``.
    """
    global PLATFORM
    if PLATFORM is not None:
        # Already resolved by an earlier call (successfully or not).
        return PLATFORM

    try:
        if torch_npu.npu.device_count() == 0:
            # Nothing to identify; skip the SoC query entirely.
            PLATFORM = NPUDevice.UNDEFINED
            return PLATFORM

        soc = torch_npu.npu.get_soc_version()
        # Inclusive (low, high) SoC-version bounds, checked in order.
        known_ranges = (
            (200, 205, NPUDevice.Duo),
            (220, 225, NPUDevice.A2),
            (250, 255, NPUDevice.A3),
            (260, 260, NPUDevice.A5),
        )
        # Resolve into a local first so that PLATFORM is only written once
        # classification has completed.
        detected = NPUDevice.UNDEFINED
        for low, high, device in known_ranges:
            if low <= soc <= high:
                detected = device
                break
        PLATFORM = detected
    except RuntimeError as exc:
        logger.warning(
            "[MindIE-SD/utils] NPU SoC version query failed. issue=torch_npu failed to return SoC version, "
            "actual_error=%s. possible_cause=NPU driver, CANN, or device environment is unavailable. "
            "Troubleshooting: check npu-smi info, CANN environment variables, and torch_npu installation.",
            exc,
        )
        PLATFORM = NPUDevice.UNDEFINED
    return PLATFORM
def is_a5_device() -> bool:
    """Report whether get_npu_device() identified the chip as NPUDevice.A5.

    Returns:
        True only for the A5 generation; False for every other member,
        including UNDEFINED.
    """
    current = get_npu_device()
    # Enum members are singletons, so identity is the idiomatic comparison.
    return current is NPUDevice.A5