import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu._C import _weak_ref_tensor
class TestNPUFormat(TestCase):
    """Tests for ``torch_npu.Format`` and the NPU format cast/query helpers."""

    def test_enum_values(self):
        """Check the integer values behind ``Format.NCHW`` and ``Format.NHWC``."""
        nchw_value = torch_npu.Format.NCHW.value
        self.assertEqual(nchw_value, 0)
        nhwc_value = torch_npu.Format.NHWC.value
        self.assertEqual(nhwc_value, 1)

    def test_npu_format_cast(self):
        """Check that ``npu_format_cast`` returns tensors in the requested format."""
        source = torch.ones(2, 2).npu()

        # The target format is given here as a plain integer instead of an
        # enum member; the result is expected to report NCHW.
        as_nchw = torch_npu.npu_format_cast(source, 0)
        self.assertEqual(torch_npu.get_npu_format(as_nchw), torch_npu.Format.NCHW)

        as_nhwc = torch_npu.npu_format_cast(source, torch_npu.Format.NHWC)
        self.assertEqual(torch_npu.get_npu_format(as_nhwc), torch_npu.Format.NHWC)

        # NOTE(review): the flag is enabled right before the FRACTAL_NZ cast,
        # presumably because that format counts as "internal" -- confirm.
        torch_npu.npu.config.allow_internal_format = True
        as_fractal_nz = torch_npu.npu_format_cast(
            source, torch_npu.Format.FRACTAL_NZ
        )
        self.assertEqual(
            torch_npu.get_npu_format(as_fractal_nz), torch_npu.Format.FRACTAL_NZ
        )

        # Cast the FRACTAL_NZ result (not the original tensor) on to ND.
        as_nd = torch_npu.npu_format_cast(as_fractal_nz, torch_npu.Format.ND)
        self.assertEqual(torch_npu.get_npu_format(as_nd), torch_npu.Format.ND)

    def test_npu_format_cast_(self):
        """Check that ``npu_format_cast_`` changes the format of its argument."""
        first = torch.ones(2, 2).npu()
        second = torch.ones(2, 2).npu()

        # Integer target format.
        torch_npu.npu_format_cast_(first, 0)
        self.assertEqual(torch_npu.get_npu_format(first), torch_npu.Format.NCHW)

        # Enum target format.
        torch_npu.npu_format_cast_(second, torch_npu.Format.NHWC)
        self.assertEqual(torch_npu.get_npu_format(second), torch_npu.Format.NHWC)

    def test_get_npu_format(self):
        """Check that ``get_npu_format`` equals both the enum member and its int."""
        tensor = torch.ones(2, 2).npu()
        torch_npu.npu_format_cast_(tensor, 0)

        reported = torch_npu.get_npu_format(tensor)
        self.assertEqual(reported, torch_npu.Format.NCHW)
        self.assertEqual(reported, 0)

    def test_get_npu_format_weak_ref(self):
        """Check ``get_npu_format`` on a tensor produced by ``_weak_ref_tensor``.

        The weak-ref tensor is expected to report the same format as the
        tensor it was created from and to have the same ``data_ptr()``.
        """
        torch_npu.npu.config.allow_internal_format = True
        original = torch.ones(2, 2).npu()
        torch_npu.npu_format_cast_(original, torch_npu.Format.FRACTAL_NZ)

        alias = _weak_ref_tensor(original)
        self.assertEqual(
            torch_npu.get_npu_format(alias), torch_npu.Format.FRACTAL_NZ
        )
        self.assertEqual(original.data_ptr(), alias.data_ptr())

    def test_weak_ref_tensor_with_storage_offset(self):
        """Check that ``_weak_ref_tensor`` keeps the geometry and data of a view.

        A strided view with a non-zero storage offset and a zero stride is
        built on top of an ``arange`` buffer; the weak-ref tensor must match
        it in size, stride, storage offset, storage byte count and contents.
        """
        shape = [2, 1, 8, 64]
        strides = [1536, 0, 192, 1]
        offset = 128

        # Highest element index the view can address; the backing buffer
        # needs exactly one more element than that.
        last_index = offset + sum(
            (extent - 1) * step for extent, step in zip(shape, strides)
        )
        num_elements = last_index + 1

        buffer = torch.arange(num_elements, dtype=torch.float32).npu()
        strided = torch.as_strided(
            buffer, size=shape, stride=strides, storage_offset=offset
        )
        alias = _weak_ref_tensor(strided)

        self.assertEqual(alias.size(), strided.size())
        self.assertEqual(alias.stride(), strided.stride())
        self.assertEqual(alias.storage_offset(), strided.storage_offset())
        self.assertEqual(alias.storage().nbytes(), strided.storage().nbytes())
        self.assertTrue(torch.equal(alias, strided))
# Only start the test runner when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    run_tests()