from collections import namedtuple, OrderedDict
from multiprocessing.reduction import ForkingPickler
import torch
import numpy as np
from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
class TestNestedTensor(TestCase):
    """Tests for ``torch.nested`` nested tensors; reference tensors are placed on NPU."""

    @staticmethod
    def _random_rows(batch_size, max_seq_len, vocab_size, inner_dim=None):
        """Return ``batch_size`` ragged rows of random integer ids in ``[0, vocab_size)``.

        Each row length comes from ``np.random.randint(1, max_seq_len)``, i.e. it lies
        in ``[1, max_seq_len - 1]`` (0 when ``max_seq_len == 0``). When ``inner_dim``
        is given, every id ``t`` is expanded into ``list(t * np.arange(inner_dim))``,
        so each row becomes a list of equal-length lists.
        """
        rows = []
        for _ in range(batch_size):
            length = 0 if max_seq_len == 0 else np.random.randint(1, max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            if inner_dim is not None:
                row = [list(item * np.arange(inner_dim)) for item in row]
            rows.append(row)
        return rows

    def _assert_unbind_matches_rows(self, rows, dtype):
        """Build a nested tensor of ``dtype`` from ``rows`` and check every unbound component.

        Each component must have ``dtype`` and equal ``torch.tensor(row, dtype=dtype)``
        moved to NPU.
        """
        # Pass dtype explicitly: without it integer data yields an int64 nested tensor,
        # which would make a "float" test silently compare ints.
        nested = torch.nested.nested_tensor(rows, dtype=dtype)
        components = nested.unbind()
        self.assertEqual(len(components), len(rows))
        for component, row in zip(components, rows):
            self.assertEqual(component.dtype, dtype)
            self.assertEqual(component, torch.tensor(row, dtype=dtype).npu())

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [16, 32])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        """Unbinding an int64 nested tensor of ragged 1-D rows returns the original rows."""
        rows = self._random_rows(batch_size, max_seq_len, vocab_size)
        self._assert_unbind_matches_rows(rows, torch.int64)

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [16, 32])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        """Unbinding an int64 nested tensor of ragged 2-D rows returns the original rows."""
        rows = self._random_rows(batch_size, max_seq_len, vocab_size, inner_dim=max_seq_len)
        self._assert_unbind_matches_rows(rows, torch.int64)

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [16, 32])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        """Unbinding a float32 nested tensor of ragged 2-D rows returns the original rows."""
        rows = self._random_rows(batch_size, max_seq_len, vocab_size, inner_dim=max_seq_len)
        self._assert_unbind_matches_rows(rows, torch.float32)

    def _test_unbind_case(self, a, b):
        """Check that unbinding a two-component nested tensor of ``a`` and ``b`` returns them."""
        nt = torch.nested.nested_tensor([a.npu(), b.npu()], dtype=a.dtype)
        nt_list = nt.unbind()
        self.assertEqual(len(nt_list), 2)
        self.assertEqual(nt_list[0], a)
        self.assertEqual(nt_list[1], b)

    def _test_asnested_case(self, a, b):
        """Check that ``as_nested_tensor`` and ``nested_tensor`` unbind to the same components."""
        nt = torch.nested.nested_tensor([a.npu(), b.npu()], dtype=a.dtype)
        nt_as = torch.nested.as_nested_tensor([a.npu(), b.npu()], dtype=a.dtype)
        nt_list = nt.unbind()
        nt_as_list = nt_as.unbind()
        self.assertEqual(len(nt_as_list), len(nt_list))
        for as_component, component in zip(nt_as_list, nt_list):
            self.assertEqual(as_component, component)

    def test_unbind_and_asnested_int64(self):
        """Run the unbind and as_nested checks on int64 components."""
        a = torch.tensor([[1, 2, 3], [4, 5, 6]])
        b = torch.tensor([[7, 8], [10, 11]])
        self._test_unbind_case(a, b)
        self._test_asnested_case(a, b)

    def test_unbind_and_asnested_float32(self):
        """Run the unbind and as_nested checks on float32 components."""
        a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
        b = torch.tensor([[7, 8], [10, 11]], dtype=torch.float32)
        self._test_unbind_case(a, b)
        self._test_asnested_case(a, b)

    def test_unbind_and_asnested_empty(self):
        """Run the unbind and as_nested checks on components with a zero-size last dim."""
        a = torch.tensor([[], []])
        b = torch.tensor([[], [], []])
        self._test_unbind_case(a, b)
        self._test_asnested_case(a, b)

    def test_default_options_nested_tensor(self):
        """An empty nested tensor on npu:0 shares default options with ``torch.tensor([]).npu()``."""
        default_nested_tensor = torch.nested.nested_tensor([], device="npu:0")
        default_tensor = torch.tensor([]).npu()
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.requires_grad, default_tensor.requires_grad)

    def test_nested_tensor_errsize(self):
        """``size()`` works on the regular batch dim and raises on an irregular dim."""
        nt = torch.nested.nested_tensor([
            torch.tensor([[1, 2, 3], [4, 5, 6]]).npu(),
            torch.tensor([[7, 8], [10, 11], [12, 13]]).npu(),
        ])
        self.assertEqual(nt.size(0), 2)
        # Component dim 0 differs (2 vs 3), so nested dim 1 has no single size.
        with self.assertRaisesRegex(RuntimeError,
                                    "Given dimension 1 is irregular and does not have a size"):
            nt.size(1)
        # A list of Python scalars yields a nested tensor with a single component.
        nt = torch.nested.nested_tensor([2])
        self.assertEqual(nt.size(0), 1)
# Expand the @parametrize-decorated tests at import time, so runners that import
# this module (pytest, unittest discovery) see the concrete parametrized variants;
# previously this only happened when the file was executed directly.
instantiate_parametrized_tests(TestNestedTensor)

if __name__ == '__main__':
    run_tests()