import contextlib
import warnings

import torch


def has_cuda():
    """Report whether PyTorch can currently use a CUDA device."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready


def has_npu():
    """Report whether an NPU backend is registered on ``torch`` and usable.

    ``torch.npu`` only exists when an NPU extension has been loaded, so its
    absence is treated as "no NPU" rather than an error.
    """
    if not hasattr(torch, "npu"):
        return False
    return torch.npu.is_available()


def has_accelerator():
    """Return True when a CUDA or NPU device is usable (CUDA is probed first)."""
    if has_cuda():
        return True
    return has_npu()


def get_default_device():
    """Pick the preferred compute device: CUDA first, then NPU, else CPU."""
    for probe, backend in ((has_cuda, "cuda"), (has_npu, "npu")):
        if probe():
            return torch.device(backend)
    return torch.device("cpu")


def get_stream_context(device):
    """Return a context manager that runs work on a fresh stream for *device*.

    CUDA devices get a new ``torch.cuda.Stream``. NPU devices get a new
    ``torch.npu.Stream`` when ``torch.npu`` exposes both ``Stream`` and
    ``stream``. Every other case yields a no-op ``contextlib.nullcontext()``.
    """
    if device.type == "cuda":
        side_stream = torch.cuda.Stream(device)
        return torch.cuda.stream(side_stream)
    if device.type != "npu":
        return contextlib.nullcontext()
    npu_mod = getattr(torch, "npu", None)
    npu_has_streams = npu_mod is not None and all(
        hasattr(npu_mod, attr) for attr in ("Stream", "stream")
    )
    if not npu_has_streams:
        return contextlib.nullcontext()
    return npu_mod.stream(npu_mod.Stream(device))


def get_autocast_context(device, enabled):
    """Return a mixed-precision autocast context manager for *device*.

    Args:
        device: a ``torch.device`` (only its ``.type`` is inspected). It is not
            inspected at all when ``enabled`` is falsy.
        enabled: when falsy, no autocast is applied.

    Returns:
        An autocast context for ``"cuda"`` and ``"npu"`` devices. Otherwise
        ``contextlib.nullcontext()``, which covers three cases: autocast is
        disabled, the device type is not handled here (e.g. ``"cpu"``), or no
        NPU autocast could be constructed. In the NPU case a warning is emitted
        so the loss of mixed precision is visible.
    """
    if not enabled:
        return contextlib.nullcontext()
    has_generic_autocast = hasattr(torch, "amp") and hasattr(torch.amp, "autocast")
    if device.type == "cuda":
        # torch.cuda.amp.autocast(...) is deprecated (FutureWarning) in recent
        # PyTorch; prefer the device-generic API, falling back on old releases.
        if has_generic_autocast:
            return torch.amp.autocast(device_type="cuda", enabled=True)
        return torch.cuda.amp.autocast(True)
    if device.type == "npu":
        if hasattr(torch, "npu") and hasattr(torch.npu, "amp") and hasattr(torch.npu.amp, "autocast"):
            return torch.npu.amp.autocast(True)
        if has_generic_autocast:
            try:
                return torch.amp.autocast(device_type="npu", enabled=True)
            except Exception as exc:  # best-effort: unsupported device_type raises
                reason = f"torch.amp.autocast rejected device_type='npu' ({exc})"
        else:
            reason = "no NPU autocast API is available"
        warnings.warn(
            f"Mixed precision was requested on NPU, but {reason}; running without autocast.",
            stacklevel=2,
        )
        return contextlib.nullcontext()
    return contextlib.nullcontext()


def empty_cache_and_sync(device):
    """Release cached accelerator memory, then wait for pending work on *device*.

    For ``"cuda"`` devices this calls ``torch.cuda.empty_cache()``. It then
    synchronizes the current stream of ``device``, or of the current CUDA device
    when ``device`` carries no index. For ``"npu"`` devices it makes the
    analogous calls when ``torch.npu`` exposes them. Any other device type is a
    no-op.

    Args:
        device: a ``torch.device``.
    """
    if device.type == "cuda":
        torch.cuda.empty_cache()
        # Pass `device` so a non-current GPU (e.g. cuda:1) is the one that gets
        # synchronized; an index-less device still resolves to the current GPU.
        torch.cuda.current_stream(device).synchronize()
        return
    if device.type == "npu" and hasattr(torch, "npu"):
        torch.npu.empty_cache()
        if hasattr(torch.npu, "current_stream"):
            # NOTE(review): this syncs the *current* NPU's stream; consider
            # passing `device` once torch_npu's current_stream(device) signature
            # is confirmed.
            torch.npu.current_stream().synchronize()