import torch
import torch_npu
__all__ = ["get_amp_supported_dtype", "is_autocast_enabled", "set_autocast_enabled", "get_autocast_dtype",
"set_autocast_dtype"]
def get_amp_supported_dtype():
if torch.npu.is_bf16_supported():
return [torch.float16, torch.bfloat16, torch.float32]
return [torch.float16, torch.float32]
def is_autocast_enabled():
return torch_npu._C.is_autocast_enabled()
def set_autocast_enabled(enable):
torch_npu._C.set_autocast_enabled(enable)
def get_autocast_dtype():
return torch_npu._C.get_autocast_dtype()
def set_autocast_dtype(dtype):
return torch_npu._C.set_autocast_dtype(dtype)