import copy
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestStorage(TestCase):
    """Compare torch_npu storage behaviour against the CPU reference implementation."""

    def test_storage_method(self):
        """Exercise typed-storage methods for every dtype and untyped storage once.

        For each dtype, a CPU storage and an NPU storage are built from the same
        random tensor. Each ``_test_*`` helper then checks one storage API on both.

        Several helpers special-case float64. There the assertions expect the NPU
        side to behave like float32:
        - half the element size / nbytes,
        - ``FloatStorage`` pickle type,
        - ``torch.float32`` dtype.
        """
        storage_types = [
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.half,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.uint8,
            torch.bfloat16,
        ]
        for dtype in storage_types:
            # The helpers below read ``dtype`` from this loop. They are redefined
            # on every iteration and only called within that same iteration.
            def _test_cpu(cpu_storage, npu_storage):
                npu_res = npu_storage.cpu()
                self.assertEqual(cpu_storage, npu_res)

            def _test_npu(cpu_storage, npu_storage):
                cpu_res = cpu_storage.npu().cpu()
                npu_res = npu_storage.cpu()
                self.assertEqual(cpu_res, npu_res)

            def _test_clone(cpu_storage, npu_storage):
                cpu_res = cpu_storage.clone()
                npu_res = npu_storage.clone().cpu()
                self.assertEqual(cpu_res, npu_res)

            def _test_copy_(cpu_storage, npu_storage):
                # Destinations start with different contents (ones vs zeros), so
                # equality afterwards shows that copy_ overwrote them.
                cpu_res = torch.ones([3, 1, 2, 2]).to(dtype).storage()
                npu_res = torch.zeros([3, 1, 2, 2]).npu().to(dtype).storage()
                cpu_res.copy_(cpu_storage)
                npu_res.copy_(npu_storage)
                self.assertEqual(cpu_res, npu_res.cpu())

            def _test_untyped(cpu_storage, npu_storage):
                cpu_res = cpu_storage.untyped()
                # Move to CPU in both branches so the comparison is always CPU vs CPU.
                npu_res = npu_storage.untyped().cpu()
                if dtype == torch.float64:
                    # The NPU side is expected to hold float32 bytes, so compare
                    # against the CPU storage converted to float first.
                    self.assertEqual(cpu_storage.float().untyped(), npu_res)
                else:
                    self.assertEqual(cpu_res, npu_res)

            def _test_data_ptr(cpu_storage, npu_storage):
                cpu_res = cpu_storage.data_ptr()
                npu_res = npu_storage.data_ptr()
                self.assertNotEqual(npu_res, None)
                self.assertNotEqual(cpu_res, npu_res)

            def _test_element_size(cpu_storage, npu_storage):
                cpu_res = cpu_storage.element_size()
                npu_res = npu_storage.element_size()
                if dtype == torch.float64:
                    self.assertEqual(cpu_res, npu_res * 2)
                else:
                    self.assertEqual(cpu_res, npu_res)

            def _test_fill_(cpu_storage, npu_storage):
                cpu_storage.fill_(1)
                npu_storage.fill_(1)
                npu_res = npu_storage.cpu()
                self.assertEqual(cpu_storage, npu_res)

            def _test_get_device(cpu_storage, npu_storage):
                npu_res = npu_storage.get_device()
                self.assertEqual(npu_res, 0)

            def _test_is_pinned(cpu_storage, npu_storage):
                cpu_res = cpu_storage.is_pinned()
                npu_res = npu_storage.is_pinned()
                self.assertEqual(cpu_res, False)
                self.assertEqual(npu_res, False)
                cpu_res = cpu_storage.is_pinned("npu")
                npu_res = npu_storage.is_pinned("npu")
                self.assertEqual(cpu_res, False)
                self.assertEqual(npu_res, False)

            def _test_pin_memory(cpu_storage, npu_storage):
                # pin_memory must return a new pinned copy and leave the source untouched.
                ori_ptr = cpu_storage.data_ptr()
                cpu_pin_storage = cpu_storage.pin_memory("npu")
                self.assertEqual(cpu_storage.data_ptr(), ori_ptr)
                self.assertNotEqual(cpu_pin_storage.data_ptr(), ori_ptr)
                self.assertEqual(cpu_storage.is_pinned("npu"), False)
                self.assertEqual(cpu_pin_storage.is_pinned("npu"), True)

            def _test_nbytes(cpu_storage, npu_storage):
                cpu_res = cpu_storage.nbytes()
                npu_res = npu_storage.nbytes()
                if dtype == torch.float64:
                    self.assertEqual(cpu_res, npu_res * 2)
                else:
                    self.assertEqual(cpu_res, npu_res)

            def _test_pickle_storage_type(cpu_storage, npu_storage):
                cpu_res = cpu_storage.pickle_storage_type()
                npu_res = npu_storage.pickle_storage_type()
                if dtype == torch.float64:
                    self.assertEqual(npu_res, "FloatStorage")
                else:
                    self.assertEqual(cpu_res, npu_res)

            def _test_size(cpu_storage, npu_storage):
                cpu_res = cpu_storage.size()
                npu_res = npu_storage.size()
                self.assertEqual(cpu_res, npu_res)

            def _test_tolist(cpu_storage, npu_storage):
                cpu_res = cpu_storage.tolist()
                npu_res = npu_storage.tolist()
                self.assertEqual(cpu_res, npu_res)

            def _test_resize_(cpu_storage, npu_storage):
                # Grow: the first 12 elements must survive and the buffer must move.
                cpu_ori_ptr = cpu_storage.data_ptr()
                npu_ori_ptr = npu_storage.data_ptr()
                cpu_storage.resize_(24)
                npu_storage.resize_(24)
                self.assertEqual(cpu_storage.size(), npu_storage.size())
                self.assertEqual(cpu_storage.tolist()[:12], npu_storage.tolist()[:12])
                self.assertNotEqual(cpu_storage.data_ptr(), cpu_ori_ptr)
                self.assertNotEqual(npu_storage.data_ptr(), npu_ori_ptr)
                # Shrink: the contents must still match and the buffer must move again.
                cpu_ori_ptr = cpu_storage.data_ptr()
                npu_ori_ptr = npu_storage.data_ptr()
                cpu_storage.resize_(8)
                npu_storage.resize_(8)
                self.assertEqual(cpu_storage.size(), npu_storage.size())
                self.assertEqual(cpu_storage.tolist(), npu_storage.tolist())
                self.assertNotEqual(cpu_storage.data_ptr(), cpu_ori_ptr)
                self.assertNotEqual(npu_storage.data_ptr(), npu_ori_ptr)
                cpu_storage.resize_(0)
                npu_storage.resize_(0)
                self.assertEqual(cpu_storage.size(), npu_storage.size())
                # Leave both storages non-empty and equal for the helpers that run later.
                cpu_storage.resize_(16)
                npu_storage.resize_(16)
                cpu_storage.fill_(4)
                npu_storage.fill_(4)

            def _test_is_shared(cpu_storage, npu_storage):
                cpu_res = cpu_storage.is_shared()
                npu_res = npu_storage.is_shared()
                self.assertEqual(cpu_res, False)
                self.assertEqual(npu_res, False)

            def _test_share_memory_(cpu_storage, npu_storage):
                npu_ori_ptr = npu_storage.data_ptr()
                cpu_res = cpu_storage.share_memory_()
                npu_res = npu_storage.share_memory_()
                # The assertions expect share_memory_ on NPU to leave the buffer
                # in place and unshared, while the CPU storage becomes shared.
                self.assertEqual(npu_storage.data_ptr(), npu_ori_ptr)
                self.assertEqual(cpu_res.is_shared(), True)
                self.assertEqual(npu_res.is_shared(), False)

            def _test_dtype(cpu_storage, npu_storage):
                cpu_res = cpu_storage.dtype
                npu_res = npu_storage.dtype
                if cpu_res == torch.float64:
                    self.assertEqual(npu_res, torch.float32)
                else:
                    self.assertEqual(npu_res, cpu_res)

            def _test_device(cpu_storage, npu_storage):
                cpu_res = cpu_storage.device
                npu_res = npu_storage.device
                self.assertEqual(cpu_res.type, "cpu")
                self.assertEqual(npu_res.type, "npu")

            def _test_datatype_cast(cpu_storage, npu_storage):
                dtypes = [
                    "bool",
                    "double",
                    "float",
                    "half",
                    "long",
                    "int",
                    "short",
                    "byte",
                    "char",
                ]
                for dt in dtypes:
                    # getattr performs the same method call as the former eval() string.
                    cpu_res = getattr(cpu_storage, dt)()
                    npu_res_cpu = getattr(npu_storage, dt)().cpu()
                    self.assertEqual(cpu_res.size(), npu_res_cpu.size())
                    self.assertEqual(cpu_res, npu_res_cpu)
                    self.assertEqual(cpu_res.tolist(), npu_res_cpu.tolist())

            # NOTE(review): SupportedDevices is applied here to a nested helper
            # with no ``self`` argument. Confirm that the device gating behaves
            # as intended in this position.
            @SupportedDevices(['Ascend910B'])
            def _test_datatype_cast_complex(cpu_storage, npu_storage):
                dtypes = [
                    "complex_double",
                    "complex_float",
                ]
                for dt in dtypes:
                    cpu_res = getattr(cpu_storage, dt)()
                    npu_res = getattr(npu_storage, dt)()
                    self.assertEqual(cpu_res.size(), npu_res.size())

            def _test_from_buffer(cpu_storage, npu_storage):
                # Ignores its arguments; checks ByteStorage.from_buffer on CPU only.
                cpu_list = [2, 3, 3, 2, 5]
                cpu_buffer = bytearray(cpu_list)
                cpu_res = torch.ByteStorage.from_buffer(cpu_buffer)
                self.assertEqual(cpu_res.tolist(), cpu_list)

            # Typed storages (TypedStorage, FloatStorage, ...) for the current dtype.
            cpu_tensor = torch.randn([3, 1, 2, 2])
            npu_tensor = cpu_tensor.npu()
            cpu_storage = cpu_tensor.to(dtype).storage()
            npu_storage = npu_tensor.to(dtype).storage()
            _test_fill_(cpu_storage, npu_storage)
            _test_cpu(cpu_storage, npu_storage)
            _test_npu(cpu_storage, npu_storage)
            _test_clone(cpu_storage, npu_storage)
            _test_copy_(cpu_storage, npu_storage)
            _test_untyped(cpu_storage, npu_storage)
            _test_data_ptr(cpu_storage, npu_storage)
            _test_element_size(cpu_storage, npu_storage)
            _test_fill_(cpu_storage, npu_storage)
            _test_get_device(cpu_storage, npu_storage)
            _test_is_pinned(cpu_storage, npu_storage)
            _test_pin_memory(cpu_storage, npu_storage)
            _test_nbytes(cpu_storage, npu_storage)
            _test_pickle_storage_type(cpu_storage, npu_storage)
            _test_size(cpu_storage, npu_storage)
            _test_tolist(cpu_storage, npu_storage)
            _test_resize_(cpu_storage, npu_storage)
            _test_is_shared(cpu_storage, npu_storage)
            _test_share_memory_(cpu_storage, npu_storage)
            _test_dtype(cpu_storage, npu_storage)
            _test_device(cpu_storage, npu_storage)
            _test_datatype_cast(cpu_storage, npu_storage)
            _test_datatype_cast_complex(cpu_storage, npu_storage)
            _test_from_buffer(cpu_storage, npu_storage)

            # Untyped storage is dtype-independent, so test it on one dtype only.
            if dtype == torch.int8:
                def _test_mps(cpu_storage, npu_storage):
                    npu_res = npu_storage.mps()
                    cpu_res = cpu_storage.mps()
                    self.assertEqual(cpu_res, npu_res)

                def _test_new(cpu_storage, npu_storage):
                    npu_res = npu_storage.new()
                    cpu_res = cpu_storage.new()
                    self.assertEqual(cpu_res.size(), npu_res.size())
                    self.assertEqual(cpu_res.type(), npu_res.type())
                    self.assertEqual(cpu_res.device.type, cpu_storage.device.type)
                    self.assertEqual(npu_res.device.type, npu_storage.device.type)

                def _test_type(cpu_storage, npu_storage):
                    npu_res = npu_storage.type()
                    self.assertEqual(npu_res, "torch.storage.UntypedStorage")

                def _test_untyped_copy_(cpu_storage, npu_storage):
                    # Untyped counterpart of _test_copy_, built with untyped_storage().
                    cpu_res = torch.ones([3, 1, 2, 2]).untyped_storage()
                    npu_res = torch.zeros([3, 1, 2, 2]).npu().untyped_storage()
                    cpu_res.copy_(cpu_storage)
                    npu_res.copy_(npu_storage)
                    self.assertEqual(cpu_res, npu_res.cpu())

                cpu_tensor = torch.randn([3, 1, 2, 2])
                npu_tensor = cpu_tensor.npu()
                cpu_storage = cpu_tensor.untyped_storage()
                npu_storage = npu_tensor.untyped_storage()
                # Untyped storages: reuse the typed helpers wherever the API overlaps.
                _test_new(cpu_storage, npu_storage)
                _test_cpu(cpu_storage, npu_storage)
                _test_npu(cpu_storage, npu_storage)
                _test_clone(cpu_storage, npu_storage)
                _test_untyped_copy_(cpu_storage, npu_storage)
                _test_untyped(cpu_storage, npu_storage)
                _test_data_ptr(cpu_storage, npu_storage)
                _test_element_size(cpu_storage, npu_storage)
                _test_fill_(cpu_storage, npu_storage)
                _test_get_device(cpu_storage, npu_storage)
                _test_is_pinned(cpu_storage, npu_storage)
                _test_pin_memory(cpu_storage, npu_storage)
                _test_nbytes(cpu_storage, npu_storage)
                _test_size(cpu_storage, npu_storage)
                _test_tolist(cpu_storage, npu_storage)
                _test_type(cpu_storage, npu_storage)
                _test_resize_(cpu_storage, npu_storage)
                _test_is_shared(cpu_storage, npu_storage)
                _test_share_memory_(cpu_storage, npu_storage)
                _test_device(cpu_storage, npu_storage)
                _test_datatype_cast(cpu_storage, npu_storage)
                _test_datatype_cast_complex(cpu_storage, npu_storage)
                _test_from_buffer(cpu_storage, npu_storage)
                with self.assertRaisesRegex(RuntimeError, "Storage device not recognized: mps"):
                    _test_mps(cpu_storage, npu_storage)

    def test_type_conversions(self):
        """Type-cast methods return the expected torch / torch.npu tensor and storage classes."""
        x = torch.randn(5, 5)
        supported_dtypes = ["float", "half", "long", "short", "int", "bool", "char", "byte"]
        for dtype in supported_dtypes:
            self.assertIsInstance(getattr(x.npu(), dtype)(), getattr(torch.npu, dtype.title() + "Tensor"))
            self.assertIsInstance(getattr(x.float().cpu(), dtype)(), getattr(torch, dtype.title() + "Tensor"))

        y = x.storage()
        for dtype in supported_dtypes:
            self.assertIsInstance(getattr(y.npu(), dtype)(), getattr(torch.npu, dtype.title() + "Storage"))
            self.assertIsInstance(getattr(y.float().cpu(), dtype)(), getattr(torch, dtype.title() + "Storage"))

    def test_deepcopy(self):
        """deepcopy allocates new storage and keeps the NPU internal format."""
        config = torch_npu.npu.config
        # Record the previous setting so this test does not leak global state into
        # later tests. If the attribute cannot be read, fall back to the old
        # behaviour (set it and do not restore).
        previous_allow_internal_format = getattr(config, "allow_internal_format", None)
        config.allow_internal_format = True
        try:
            x = torch.tensor([1])
            y = copy.deepcopy(x)
            self.assertEqual(x, y)

            x = torch.tensor([1]).npu()
            y = copy.deepcopy(x)
            self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())

            # 29 is the NPU format id this test casts to (presumably FRACTAL_NZ;
            # TODO confirm). The copy must report the same format.
            x = torch.rand(3, 3).npu()
            x = torch_npu.npu_format_cast(x, 29)
            y = copy.deepcopy(x)
            self.assertEqual(torch_npu.get_npu_format(y), 29)
            self.assertEqual(x, y)
        finally:
            if previous_allow_internal_format is not None:
                config.allow_internal_format = previous_allow_internal_format
if __name__ == '__main__':
    # When executed as a script, hand off to run_tests from torch_npu.testing.testcase.
    run_tests()