import contextlib
import torch
def has_cuda():
    """Report whether PyTorch can see a usable CUDA device.

    Returns:
        The result of ``torch.cuda.is_available()``.
    """
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
def has_npu():
    """Report whether an NPU backend is registered on ``torch`` and usable.

    ``torch.npu`` only exists when an NPU extension has been loaded, so the
    attribute is probed before ``is_available()`` is called.

    Returns:
        ``False`` when ``torch`` has no ``npu`` attribute, otherwise the
        result of ``torch.npu.is_available()``.
    """
    if not hasattr(torch, "npu"):
        return False
    return torch.npu.is_available()
def has_accelerator():
    """Report whether either a CUDA or an NPU device is available.

    CUDA is probed first; the NPU probe only runs when CUDA is unavailable.

    Returns:
        ``True`` when ``has_cuda()`` is truthy, otherwise ``has_npu()``.
    """
    if has_cuda():
        return True
    return has_npu()
def get_default_device():
    """Pick the preferred ``torch.device`` for this process.

    The preference order is CUDA, then NPU, then CPU as the final fallback.

    Returns:
        ``torch.device("cuda")``, ``torch.device("npu")`` or ``torch.device("cpu")``.
    """
    for probe, device_name in ((has_cuda, "cuda"), (has_npu, "npu")):
        if probe():
            return torch.device(device_name)
    return torch.device("cpu")
def get_stream_context(device):
    """Return a context manager that runs work on a freshly created stream.

    A new stream object is constructed on every call. The NPU path is used
    only when ``torch.npu`` exposes both ``Stream`` and ``stream``.

    Args:
        device: A ``torch.device``; its ``type`` selects the backend and the
            device itself is passed to the stream constructor.

    Returns:
        ``torch.cuda.stream(...)`` / ``torch.npu.stream(...)`` wrapping a new
        stream, or ``contextlib.nullcontext()`` for any other device type or
        when the NPU stream API is missing.
    """
    if device.type == "cuda":
        fresh_stream = torch.cuda.Stream(device)
        return torch.cuda.stream(fresh_stream)
    if device.type == "npu" and hasattr(torch, "npu"):
        backend = torch.npu
        if all(hasattr(backend, attr) for attr in ("Stream", "stream")):
            return backend.stream(backend.Stream(device))
    return contextlib.nullcontext()
def get_autocast_context(device, enabled):
    """Return an autocast context manager for ``device``, or a no-op context.

    Args:
        device: A ``torch.device``; only its ``type`` attribute is inspected.
        enabled: When falsy, autocast is not used at all.

    Returns:
        For CUDA, an autocast context from ``torch.amp.autocast`` (or the
        legacy ``torch.cuda.amp.autocast`` on PyTorch builds without
        ``torch.amp.autocast``). For NPU, ``torch.npu.amp.autocast`` when
        present, else ``torch.amp.autocast(device_type="npu")`` when that can
        be constructed. In every other case (disabled, other device types such
        as CPU, or no usable NPU autocast) ``contextlib.nullcontext()``.
    """
    if not enabled:
        return contextlib.nullcontext()
    has_generic_autocast = hasattr(torch, "amp") and hasattr(torch.amp, "autocast")
    if device.type == "cuda":
        # torch.cuda.amp.autocast(...) is deprecated (FutureWarning since
        # PyTorch 2.4); use the device-generic API and keep the legacy call
        # only as a fallback for releases that lack torch.amp.autocast.
        if has_generic_autocast:
            return torch.amp.autocast(device_type="cuda", enabled=True)
        return torch.cuda.amp.autocast(True)
    if device.type == "npu":
        if hasattr(torch, "npu") and hasattr(torch.npu, "amp") and hasattr(torch.npu.amp, "autocast"):
            return torch.npu.amp.autocast(True)
        if has_generic_autocast:
            try:
                return torch.amp.autocast(device_type="npu", enabled=True)
            except Exception:
                # Deliberate best effort: builds without an "npu" autocast
                # backend reject this device_type, so run without autocast.
                return contextlib.nullcontext()
    return contextlib.nullcontext()
def empty_cache_and_sync(device):
    """Release cached accelerator memory and wait on ``device``'s current stream.

    For CUDA, calls ``torch.cuda.empty_cache()`` and then synchronizes the
    current stream of ``device``. For NPU (only when ``torch.npu`` exists),
    calls ``torch.npu.empty_cache()`` and, if ``torch.npu.current_stream`` is
    available, synchronizes that device's current stream. Any other device
    type is a no-op.

    Args:
        device: A ``torch.device``; its ``type`` selects the backend and the
            device is passed through so the right device's stream is synced.
    """
    if device.type == "cuda":
        torch.cuda.empty_cache()
        # Pass ``device`` explicitly: without an argument current_stream()
        # refers to the *current* CUDA device, which need not be ``device``
        # (e.g. "cuda:1" while device 0 is current).
        torch.cuda.current_stream(device).synchronize()
        return
    if device.type == "npu" and hasattr(torch, "npu"):
        torch.npu.empty_cache()
        if hasattr(torch.npu, "current_stream"):
            # Same reasoning as the CUDA branch; assumes torch.npu mirrors the
            # torch.cuda current_stream(device) signature.
            torch.npu.current_stream(device).synchronize()