import sys
import types
import logging
from typing import Any, Optional
import torch
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# Backend availability is probed once, when this module is imported.
IS_CUDA_AVAILABLE = torch.cuda.is_available()

try:
    # Imported only to find out whether it is importable; the name itself is
    # not used afterwards in the visible code.
    # NOTE(review): presumably this import is also what registers ``torch.npu``
    # -- confirm against the torch_npu documentation.
    import torch_npu  # noqa: F401
except Exception:
    # Any failure while importing torch_npu is treated as "no NPU present".
    IS_NPU_AVAILABLE = False
else:
    IS_NPU_AVAILABLE = True

if IS_NPU_AVAILABLE:
    # Import-time side effect: turn off the NPU "internal format" option.
    torch.npu.config.allow_internal_format = False
def accelerator_getattr(module, fallback_module):
    """Build a ``__getattr__`` hook for ``module`` that defers to ``fallback_module``.

    Args:
        module: Object on which resolved attributes are stored after the
            first successful lookup.
        fallback_module: Object that is searched for the requested name.

    Returns:
        A one-argument callable ``__getattr__(name)``. It returns the
        attribute found on ``fallback_module`` (after also setting it on
        ``module``) and raises ``AttributeError`` when ``fallback_module``
        does not provide ``name``.
    """
    # Private marker so that attributes whose value is ``None`` still count
    # as found.
    not_found = object()

    def __getattr__(name):
        resolved = getattr(fallback_module, name, not_found)
        if resolved is not_found:
            raise AttributeError(f'module {module} and {fallback_module} has no attribute {name}.')
        # Store the value on ``module`` so it is available there directly.
        setattr(module, name, resolved)
        return resolved

    return __getattr__
def set_accelerator_compatible(fallback_module=None):
    """Replace ``torch.accelerator`` with a shim module that has a fallback.

    A fresh module named ``'torch.accelerator'`` is created, every attribute
    of the current ``torch.accelerator`` whose name does not start with a
    double underscore is copied onto it, and a ``__getattr__`` hook built by
    ``accelerator_getattr`` is attached so that lookups missing on the shim
    are tried on ``fallback_module``. The shim is then installed both as
    ``torch.accelerator`` and as ``sys.modules['torch.accelerator']``.

    Args:
        fallback_module: Object consulted for attributes the shim lacks.
    """
    shim = types.ModuleType('torch.accelerator')
    shim.__doc__ = f'Fallback accelerator module that delegates to {get_device_type()}'

    # Copy everything except dunder names from the existing module.
    current = torch.accelerator
    copied_names = [entry for entry in dir(current) if not entry.startswith('__')]
    for entry in copied_names:
        setattr(shim, entry, getattr(current, entry))

    shim.__getattr__ = accelerator_getattr(shim, fallback_module)

    # Install the shim for both attribute access and the import system.
    torch.accelerator = shim
    sys.modules['torch.accelerator'] = shim
def get_dist_comm_backend(cpu: bool = False) -> str:
    """Pick the distributed communication backend string for this machine.

    Args:
        cpu: When truthy, the returned string also maps the CPU to ``gloo``
            (``"cpu:gloo,cuda:nccl"`` or ``"cpu:gloo,npu:hccl"``).

    Returns:
        ``"nccl"`` (CUDA) or ``"hccl"`` (NPU), or the combined form described
        above when ``cpu`` is truthy. CUDA takes precedence over NPU.

    Raises:
        RuntimeError: If neither CUDA nor NPU is available, regardless of
            ``cpu``.
    """
    if IS_CUDA_AVAILABLE:
        return "cpu:gloo,cuda:nccl" if cpu else "nccl"
    if IS_NPU_AVAILABLE:
        return "cpu:gloo,npu:hccl" if cpu else "hccl"
    raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
def get_device_type() -> str:
    """Return the device type string for the current machine.

    Returns:
        ``"cuda"`` if CUDA is available, otherwise ``"npu"`` if the NPU
        backend is available, otherwise ``"cpu"``.
    """
    if IS_CUDA_AVAILABLE:
        return "cuda"
    if IS_NPU_AVAILABLE:
        return "npu"
    return "cpu"
def get_torch_device() -> Any:
    """Return the ``torch`` sub-namespace for the active device type.

    Returns:
        ``getattr(torch, get_device_type())``, e.g. ``torch.cuda`` or
        ``torch.npu``. If ``torch`` has no attribute with that name, a
        warning is logged and ``torch.cuda`` is returned instead.
    """
    device_name = get_device_type()
    # Private marker to tell "attribute missing" apart from any real value.
    absent = object()
    namespace = getattr(torch, device_name, absent)
    if namespace is not absent:
        return namespace
    logger.warning(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
    return torch.cuda
def get_device_name() -> str:
    """Return the device name reported by the active backend, e.g. A100, H100."""
    backend = get_torch_device()
    return backend.get_device_name()
def synchronize() -> None:
    """Call ``synchronize()`` on the active backend namespace."""
    backend = get_torch_device()
    backend.synchronize()
def empty_cache() -> None:
    """Call ``empty_cache()`` on the active backend namespace."""
    backend = get_torch_device()
    backend.empty_cache()
def create_stream(device: Optional[torch.device] = None, priority: int = 0) -> Any:
    """Create a stream on the active backend.

    Args:
        device: Forwarded as the ``device`` keyword of the backend's ``Stream``.
        priority: Forwarded as the ``priority`` keyword of the backend's ``Stream``.

    Returns:
        The object produced by the backend's ``Stream`` constructor.
    """
    stream_cls = get_torch_device().Stream
    return stream_cls(device=device, priority=priority)
def create_event(enable_timing: bool = False, blocking: bool = False) -> Any:
    """Create an event on the active backend.

    Args:
        enable_timing: Forwarded as the ``enable_timing`` keyword of the backend's ``Event``.
        blocking: Forwarded as the ``blocking`` keyword of the backend's ``Event``.

    Returns:
        The object produced by the backend's ``Event`` constructor.
    """
    event_cls = get_torch_device().Event
    return event_cls(enable_timing=enable_timing, blocking=blocking)
def get_current_stream() -> Any:
    """Return the result of ``current_stream()`` on the active backend."""
    backend = get_torch_device()
    return backend.current_stream()
def switch_to_specified_stream(stream) -> Any:
    """Return the result of the active backend's ``stream(stream)`` call.

    Args:
        stream: Passed unchanged to the backend's ``stream`` function.
    """
    backend = get_torch_device()
    return backend.stream(stream)
def get_memory_reserved():
    """Return ``memory_reserved()`` from ``torch.npu`` if NPU is available, else from ``torch.cuda.memory``."""
    memory_api = torch.npu if IS_NPU_AVAILABLE else torch.cuda.memory
    return memory_api.memory_reserved()
def get_max_memory_reserved():
    """Return ``max_memory_reserved()`` from ``torch.npu`` if NPU is available, else from ``torch.cuda.memory``."""
    memory_api = torch.npu if IS_NPU_AVAILABLE else torch.cuda.memory
    return memory_api.max_memory_reserved()
def get_memory_allocated():
    """Return ``memory_allocated()`` from ``torch.npu`` if NPU is available, else from ``torch.cuda.memory``."""
    memory_api = torch.npu if IS_NPU_AVAILABLE else torch.cuda.memory
    return memory_api.memory_allocated()
def get_max_memory_allocated():
    """Return ``max_memory_allocated()`` from ``torch.npu`` if NPU is available, else from ``torch.cuda.memory``."""
    memory_api = torch.npu if IS_NPU_AVAILABLE else torch.cuda.memory
    return memory_api.max_memory_allocated()
def reset_peak_memory_stats():
    """Call ``reset_peak_memory_stats()`` on ``torch.npu`` if NPU is available, else on ``torch.cuda.memory``.

    The backend call's return value is passed through to the caller.
    """
    memory_api = torch.npu if IS_NPU_AVAILABLE else torch.cuda.memory
    return memory_api.reset_peak_memory_stats()
def set_allow_hf32(allow_hf32=None) -> None:
    """Apply the HF32/TF32 switch for the active device type.

    Args:
        allow_hf32: Value to assign. ``None`` means "leave the current
            setting untouched" and makes this call a no-op.

    When NPU is available the value is written to
    ``torch.npu.aclnn.allow_hf32``; otherwise it is written to
    ``torch.backends.cudnn.allow_tf32``.
    """
    if allow_hf32 is None:
        return
    if not IS_NPU_AVAILABLE:
        torch.backends.cudnn.allow_tf32 = allow_hf32
        return
    torch.npu.aclnn.allow_hf32 = allow_hf32