"""
Object cache client python interface Test.
"""
from __future__ import absolute_import
import os
import random
import json
import time
import unittest
from yr.datasystem.hetero_client import (
HeteroClient,
Blob,
DeviceBlobList,
)
try:
import acl
except ImportError:
acl = None
def check_acl():
"""
Features: Return False if import acl failed or build without hetero.
"""
return acl is not None and os.getenv("BUILD_HETERO", "off") == "on"
class TestDeviceOcClientMethods(unittest.TestCase):
"""
Features: Object cache client python interface test.
"""
@classmethod
def setUpClass(cls):
root_dir = os.path.dirname(os.path.abspath(".."))
worker_env_path = os.path.join(root_dir, 'output', 'datasystem', 'service', 'worker_config.json')
with open(worker_env_path, "r") as f:
config = json.load(f)
work_address = config.get("worker_address", {})
TestDeviceOcClientMethods.work_addr = work_address.get("value")
(
TestDeviceOcClientMethods.host,
port,
) = TestDeviceOcClientMethods.work_addr.split(":")
TestDeviceOcClientMethods.port = int(port)
@staticmethod
def random_str(slen=10):
"""
Features: Generate a random string for test.
"""
seed = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#%^*()_+=-"
sa = []
for _ in range(slen):
sa.append(random.choice(seed))
return "".join(sa)
@staticmethod
def init_test_hetero_client():
"""
Features: Init test client.
"""
client = HeteroClient(
TestDeviceOcClientMethods.host,
TestDeviceOcClientMethods.port,
60000,
"",
"",
"",
"",
"QTWAOYTTINDUT2QVKYUC",
"MFyfvK41ba2giqM7**********KGpownRZlmVmHc",
)
client.init()
return client
def get_all_futures(self, future_lists):
"""Function Waiting for all futures to end."""
timeout = 30000
for future_list in future_lists:
for future in future_list:
future.get(timeout)
def check_ptr_content(self, dev_ptr: int, value: str):
"""
Features: Check the device pointer content
"""
size = len(value)
output_byte = bytes("0", "utf-8").zfill(size)
output_ptr = acl.util.bytes_to_ptr(output_byte)
ret = acl.rt.memcpy(output_ptr, size, dev_ptr, size, 2)
self.assertEqual(ret, 0)
self.assertEqual(output_byte, value.encode())
def batch_data_check(self, out_data_blob_list, test_value):
"""Features: check mget result with mset"""
for batch_info in out_data_blob_list:
for info in batch_info.blob_list:
self.check_ptr_content(info.dev_ptr, test_value)
@unittest.skip
def test_mset_and_mget(self):
"""
Features: Test mset and mget device object.
"""
device_id = 0
acl.init()
acl.rt.set_device(device_id)
client = self.init_test_hetero_client()
obj_count = 128
data_size = 1024 * 1024
test_value = self.random_str(data_size)
object_key_list = [self.random_str(10) for _ in range(obj_count)]
in_data_blob_list = self._create_blob_list(device_id, obj_count, data_size, test_value)
out_data_blob_list = self._create_empty_blob_list(device_id, obj_count, data_size)
client.mset_d2h(object_key_list, in_data_blob_list)
client.mget_h2d(object_key_list, out_data_blob_list, 60000)
client.delete(object_key_list)
self.batch_data_check(out_data_blob_list, test_value)
acl.finalize()
@unittest.skip
def test_async_dev_delete(self):
"""
Features: Test mset and mget device object.
"""
src_device_id, dest_device_id = 6, 5
obj_count = 128
data_size = 1024 * 1024
test_value = self.random_str(data_size)
object_key_list = [self.random_str(10) for _ in range(obj_count)]
child1 = os.fork()
if child1 == 0:
acl.init()
acl.rt.set_device(src_device_id)
in_data_blob_list = self._create_blob_list(src_device_id, obj_count, data_size, test_value)
client = self.init_test_hetero_client()
client.dev_mset(object_key_list, in_data_blob_list)
time.sleep(2)
acl.finalize()
else:
child2 = os.fork()
if child2 == 0:
acl.init()
acl.rt.set_device(dest_device_id)
out_data_blob_list = self._create_empty_blob_list(dest_device_id, obj_count, data_size)
client = self.init_test_hetero_client()
failed_keys = client.dev_mget(object_key_list, out_data_blob_list, 10 * 1000)
self.batch_data_check(out_data_blob_list, test_value)
self.assertEqual(len(failed_keys), 0)
future = client.async_dev_delete(object_key_list)
failed_keys = future.get()
self.assertEqual(len(failed_keys), 0)
acl.finalize()
else:
os.waitpid(child1, 0)
os.waitpid(child2, 0)
@unittest.skip
def test_async_mset_and_async_mget(self):
"""
Features: Test async_mset and async_mget device object.
"""
device_id = 0
acl.init()
acl.rt.set_device(device_id)
client = self.init_test_hetero_client()
obj_count = 128
data_size = 1024 * 1024
timeout = 30 * 1000
test_value = self.random_str(data_size)
object_key_list = [self.random_str(10) for _ in range(obj_count)]
in_data_blob_list = self._create_blob_list(device_id, obj_count, data_size, test_value)
out_data_blob_list = self._create_empty_blob_list(device_id, obj_count, data_size)
mset_future = client.async_mset_d2h(object_key_list, in_data_blob_list)
failed_keys = mset_future.get(timeout)
self.assertEqual(len(failed_keys), 0)
mget_future = client.async_mget_h2d(object_key_list, out_data_blob_list, 60000)
failed_keys = mget_future.get(timeout)
self.assertEqual(len(failed_keys), 0)
client.delete(object_key_list)
self.batch_data_check(out_data_blob_list, test_value)
acl.finalize()
@unittest.skip
def test_async_mget_fail(self):
"""
Features: Test async_mset_d2h and async_mget_h2d device object.
"""
device_id = 0
acl.init()
acl.rt.set_device(device_id)
client = self.init_test_hetero_client()
data_size = 1024 * 1024
obj_count = 128
test_value = self.random_str(data_size)
object_key_list = [self.random_str(10) for _ in range(obj_count)]
in_data_blob_list = self._create_blob_list(device_id, obj_count, data_size, test_value)
out_data_blob_list = self._create_empty_blob_list(device_id, obj_count, data_size)
not_exit_key = self.random_str(10)
mset_future = client.async_mset_d2h(object_key_list, in_data_blob_list)
object_key_list[0] = not_exit_key
mget_future = client.async_mget_h2d(object_key_list, out_data_blob_list, 60000)
failed_keys = mset_future.get()
self.assertEqual(len(failed_keys), 0)
failed_keys = mget_future.get()
self.assertEqual(len(failed_keys), 1)
self.assertEqual(failed_keys, [not_exit_key])
client.delete(object_key_list)
acl.finalize()
def _create_blob_list(self, device_id, obj_count, data_size, test_value):
blob_list = []
for _ in range(obj_count):
tmp_batch_list = []
for _ in range(4):
dev_ptr, _ = acl.rt.malloc(data_size, 0)
acl.rt.memcpy(dev_ptr, data_size,
acl.util.bytes_to_ptr(test_value.encode()),
data_size, 1)
blob = Blob(dev_ptr, data_size)
tmp_batch_list.append(blob)
blob_list.append(DeviceBlobList(device_id, tmp_batch_list))
return blob_list
def _create_empty_blob_list(self, device_id, obj_count, data_size):
blob_list = []
for _ in range(obj_count):
tmp_batch_list = []
for _ in range(4):
dev_ptr, _ = acl.rt.malloc(data_size, 0)
blob = Blob(dev_ptr, data_size)
tmp_batch_list.append(blob)
blob_list.append(DeviceBlobList(device_id, tmp_batch_list))
return blob_list