from functools import wraps
import torch
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu.utils.storage import _reduce_ex
__all__ = []
def _npu(self, *args, **kwargs):
    """Forward to the native ``torch_npu._C.npu`` binding.

    The tensor and every positional/keyword argument are passed through
    unchanged, and the binding's result is returned as-is. (Presumably the
    NPU analogue of ``Tensor.cuda`` -- TODO confirm against the binding.)
    """
    native_npu = torch_npu._C.npu
    return native_npu(self, *args, **kwargs)
@property
def _is_npu(self):
    """Read-only property delegating to the native ``torch_npu._C.is_npu`` check.

    Returns whatever the binding reports for this tensor (presumably a bool
    telling whether it lives on an NPU device -- TODO confirm).
    """
    native_is_npu = torch_npu._C.is_npu
    return native_is_npu(self)
class _NPUTensortypeCache(object):
    """Lazily built registry of the legacy ``torch_npu.npu.*Tensor`` type classes.

    The class-level containers stay empty until :meth:`tensortype_list_dict_init`
    runs. After that they hold the NPU tensor-type classes, plus a lookup from
    their dotted names to the class. Both the ``torch_npu.npu.`` and
    ``torch.npu.`` spellings are keys.

    Initialization is deferred to the first call. This is presumably because
    ``torch_npu.npu`` does not expose these classes yet when this module is
    imported -- TODO confirm.
    """

    # Attribute names of the tensor-type classes on ``torch_npu.npu``.
    # This single tuple drives both the class list and the string keys, so the
    # two can no longer drift apart the way two hand-maintained parallel lists could.
    _TENSORTYPE_NAMES = (
        "BoolTensor",
        "ByteTensor",
        "CharTensor",
        "DoubleTensor",
        "FloatTensor",
        "HalfTensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "BFloat16Tensor",
    )

    init = False  # True once tensortype_list_dict_init has populated the caches.
    tensortype_list = []  # NPU tensor-type classes, in _TENSORTYPE_NAMES order.
    tensortype_str_list = []  # "torch_npu.npu.<Name>" strings, in _TENSORTYPE_NAMES order.
    tensortype_dict = {}  # "torch_npu.npu.<Name>" / "torch.npu.<Name>" -> class.

    @classmethod
    def tensortype_list_dict_init(cls):
        """Populate the class-level caches on the first call; later calls are no-ops.

        Raises:
            AttributeError: if ``torch_npu.npu`` lacks one of the expected classes.
                All lookups happen before any cache is touched, so nothing is
                cached in that case and a later call can retry.
        """
        if cls.init:
            return
        tensortypes = [getattr(torch_npu.npu, name) for name in cls._TENSORTYPE_NAMES]
        # Extend in place so references to the class-level list taken earlier see the entries.
        cls.tensortype_list += tensortypes
        cls.tensortype_str_list = ["torch_npu.npu." + name for name in cls._TENSORTYPE_NAMES]
        # Pair names with the classes just looked up. Pairing with the whole of
        # cls.tensortype_list could misalign if it already held entries.
        for name, tensortype in zip(cls._TENSORTYPE_NAMES, tensortypes):
            cls.tensortype_dict["torch_npu.npu." + name] = tensortype
            cls.tensortype_dict["torch.npu." + name] = tensortype
        cls.init = True

    @classmethod
    def get_tensortype_list(cls):
        """Return the cached list of tensor-type classes (empty before init)."""
        return cls.tensortype_list

    @classmethod
    def get_tensortype_dict(cls):
        """Return the cached name -> tensor-type class dict (empty before init)."""
        return cls.tensortype_dict
def _npu_type(self, dtype=None, non_blocking=False, **kwargs):
    """NPU-aware replacement for ``torch.Tensor.type``.

    The NPU path applies when ``dtype`` names one of the cached NPU tensor
    types, given either as a string or as the class itself. Accepted strings
    look like ``"torch_npu.npu.FloatTensor"`` or ``"torch.npu.FloatTensor"``.
    In that case the tensor is converted with ``Tensor.to`` to that class's
    ``dtype`` on an NPU device.

    Every other request, including ``dtype=None``, is delegated unchanged to
    ``self.type_raw``, the original ``torch.Tensor.type`` that
    ``_add_tensor_methods`` saves before patching.

    Args:
        dtype: ``None``, a type-name string, or a type/dtype object.
        non_blocking: forwarded to ``Tensor.to`` or ``type_raw``.
        **kwargs: forwarded to ``type_raw``. On the NPU path only
            ``memory_format`` is honoured (forwarded to ``Tensor.to``).

    Returns:
        The converted tensor, or whatever ``type_raw`` returns.
    """
    if dtype is None:
        return self.type_raw(dtype, non_blocking, **kwargs)

    _NPUTensortypeCache.tensortype_list_dict_init()
    tensortype_dict = _NPUTensortypeCache.get_tensortype_dict()
    if isinstance(dtype, str) and dtype in tensortype_dict:
        target_dtype = tensortype_dict[dtype].dtype
    elif dtype in _NPUTensortypeCache.get_tensortype_list():
        target_dtype = dtype.dtype
    else:
        return self.type_raw(dtype, non_blocking, **kwargs)

    # Mirror torch.Tensor.type: a tensor already on an NPU keeps its current
    # device (including the index). Only tensors on another device type are
    # moved to the default 'npu' device.
    device = self.device if torch_npu._C.is_npu(self) else 'npu'
    to_kwargs = {}
    if kwargs.get('memory_format') is not None:
        to_kwargs['memory_format'] = kwargs['memory_format']
    return self.to(dtype=target_dtype, device=device, non_blocking=non_blocking, **to_kwargs)
def _add_tensor_methods():
    """Install the NPU-aware methods on ``torch.Tensor``.

    Replaces ``torch.Tensor.type`` with ``_npu_type``. The original is kept as
    ``torch.Tensor.type_raw``, which ``_npu_type`` falls back to.
    Also installs ``_reduce_ex`` as ``torch.Tensor.__reduce_ex__``.
    Safe to call more than once.
    """
    # Without this guard, a second call would store _npu_type itself as
    # type_raw, and every fallback inside _npu_type would recurse forever.
    if torch.Tensor.type is not _npu_type:
        torch.Tensor.type_raw = torch.Tensor.type
    torch.Tensor.type = _npu_type
    torch.Tensor.__reduce_ex__ = _reduce_ex