import random
import unittest
from torch_npu.distributed import ParallelStore
class ParallelStoreTest(unittest.TestCase):
    """Exercise a ParallelStore server/client pair over a local TCP port.

    Every test gets a fresh server and client created in ``setUp``. The port
    is the shared base port plus a running offset. Both values live on the
    class because unittest creates a new TestCase instance for each test
    method: per-instance counters would restart at zero for every test, so
    the offset would never advance and two tests could pick the same port.
    """

    # Random base port, chosen once per process by the first instance created.
    _begin_port = None
    # Number of ports handed out so far; incremented by every setUp call.
    _port_offset = 0

    def __init__(self, method_name='runTest'):
        """Create the test case and pick the shared base port if still unset.

        Args:
            method_name: Name of the test method this instance will run.
        """
        super().__init__(method_name)
        if ParallelStoreTest._begin_port is None:
            ParallelStoreTest._begin_port = random.randint(10000, 15000)
        # Port used by this instance's server/client; assigned in setUp.
        self._tcp_port = None
        self._client = None
        self._server = None

    def setUp(self):
        """Start a server and a client on a port no earlier setUp has used."""
        ParallelStoreTest._port_offset += 1
        self._tcp_port = ParallelStoreTest._begin_port + ParallelStoreTest._port_offset
        self._server = ParallelStore(
            port=self._tcp_port,
            is_server=True,
            wait_workers=False,
            multi_tenant=True,
        )
        self._client = ParallelStore(port=self._tcp_port, is_server=False)

    def tearDown(self):
        """Drop the references to the client and server created in setUp."""
        # NOTE(review): presumably releasing the objects closes their
        # connections; confirm against the ParallelStore implementation.
        self._client = None
        self._server = None

    def test_client_set_and_server_get(self):
        """A value set through the client is returned unchanged by the server."""
        key = 'key/ParallelStoreTest/client_set_and_server_get'
        value = b'value/ParallelStoreTest/client_set_and_server_get'
        self._client.set(key, value)
        result = self._server.get(key)
        self.assertEqual(value, result)

    def test_client_server_add(self):
        """Alternating server and client ``add`` calls share one counter."""
        key = 'key/ParallelStoreTest/client_server_add'
        expected = 1
        value = self._client.add(key, 1)
        self.assertEqual(expected, value)
        for _ in range(100):
            # Each side increments by one and must observe the running total.
            expected += 1
            value = self._server.add(key, 1)
            self.assertEqual(expected, value)
            expected += 1
            value = self._client.add(key, 1)
            self.assertEqual(expected, value)

    def test_set_delete_key_and_key_count(self):
        """``num_keys`` grows by one per ``set`` and shrinks by one per ``delete_key``."""
        key_base = 'key/ParallelStoreTest/set_delete_key_and_key_count'
        value_base = 'value/ParallelStoreTest/set_delete_key_and_key_count'
        keys = []
        for i in range(100):
            key = f'{key_base}/{i}'
            value = f'{value_base}/{i}'
            keys.append(key)
            # Compare counts around each operation instead of assuming the
            # store starts empty.
            old_key_count = self._server.num_keys()
            self._client.set(key, value)
            new_key_count = self._server.num_keys()
            self.assertEqual(old_key_count + 1, new_key_count)
        for key in keys:
            old_key_count = self._server.num_keys()
            self._client.delete_key(key)
            new_key_count = self._server.num_keys()
            self.assertEqual(old_key_count - 1, new_key_count)

    def test_multi_server_set_get(self):
        """A second server-mode store on the same port sees the same counter."""
        key = 'key/ParallelStoreTest/test_multi_server_set_get'
        # Reuse the exact port chosen in setUp so both server-mode stores
        # are created against the same endpoint.
        server2 = ParallelStore(
            port=self._tcp_port,
            is_server=True,
            wait_workers=False,
            multi_tenant=True,
        )
        value1 = server2.add(key, 100)
        # Adding zero reads the current counter value without changing it.
        value2 = self._server.add(key, 0)
        self.assertEqual(value1, value2)
# Run the whole suite through unittest's CLI when executed as a script.
if __name__ == '__main__':
    unittest.main()