"""
Object cache client python interface Test.
"""
from __future__ import absolute_import
import math
import os
import random
import json
import threading
import time
import unittest
from yr.datasystem.object_client import Buffer, ConsistencyType, ObjectClient
class TestOcClientMethods(unittest.TestCase):
"""
Features: Object cache client python interface test.
"""
@classmethod
def setUpClass(cls):
root_dir = os.path.dirname(os.path.abspath('..'))
worker_env_path = os.path.join(root_dir, "output", "datasystem", "service", "worker_config.json")
with open(worker_env_path, "r") as f:
config = json.load(f)
work_address = config.get("worker_address", {})
TestOcClientMethods.work_addr = work_address.get("value")
TestOcClientMethods.host, port = TestOcClientMethods.work_addr.split(":")
TestOcClientMethods.port = int(port)
TestOcClientMethods.BUFFER = Buffer()
@staticmethod
def random_str(slen=10):
"""
Features: Generate a random string for test.
"""
seed = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#%^*()_+=-"
sa = []
for _ in range(slen):
sa.append(random.choice(seed))
return ''.join(sa)
@staticmethod
def write_buffer(value):
"""
Features: Write data to the buffer, for concurrency test only.
"""
TestOcClientMethods.BUFFER.wlatch()
TestOcClientMethods.BUFFER.memory_copy(value)
TestOcClientMethods.BUFFER.unwlatch()
@staticmethod
def read_buffer():
"""
Features: Read data from the buffer, for concurrency test only.
"""
TestOcClientMethods.BUFFER.rlatch()
value = TestOcClientMethods.BUFFER.immutable_data()
TestOcClientMethods.BUFFER.unrlatch()
return value
@staticmethod
def init_test_client():
"""
Features: Init test client.
"""
client = ObjectClient(TestOcClientMethods.host, TestOcClientMethods.port)
client.init()
return client
def test_client_put_get_succeed(self):
"""
Features: Put an object successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
param = {"consistency_type": ConsistencyType.PRAM}
client.put(object_key, value, param)
buffer_list = client.get([object_key], 5)
self.assertEqual(value, buffer_list[0].immutable_data())
def test_client_get_immutable_and_release_buffer(self):
"""
Features: Put an object successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
param = {"consistency_type": ConsistencyType.PRAM}
client.put(object_key, value, param)
buffer_list = client.get([object_key], 5)
read_data = buffer_list[0].immutable_data()
buffer_list = None
self.assertEqual(value, read_data)
def test_client_create_publish_succeed(self):
"""
Features: Put an object successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
size = len(value)
param = {"consistency_type": ConsistencyType.PRAM}
buffer = client.create(object_key, size, param)
self.assertEqual(buffer.get_size(), size)
buffer.wlatch()
buffer.memory_copy(value)
buffer.publish()
buffer.unwlatch()
def test_create_publish_get_bytearray_succeed(self):
"""
Features: Put and get a bytearray type data successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytearray(self.random_str(50), encoding='utf8')
size = len(value)
buffer = client.create(object_key, size)
buffer.wlatch()
buffer.memory_copy(value)
buffer.seal()
buffer.unwlatch()
buffer_list = client.get([object_key], 5)
self.assertEqual(value, buffer_list[0].immutable_data())
def test_client_put_get_memoryview_succeed(self):
"""
Features: Put and get a memoryview type data successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
byte_value = bytes(self.random_str(50), encoding='utf8')
value = memoryview(byte_value)
size = len(value)
buffer = client.create(object_key, size)
buffer.wlatch()
buffer.memory_copy(value)
buffer.seal()
buffer.unwlatch()
buffer_list = client.get([object_key], 5)
self.assertEqual(value, buffer_list[0].immutable_data())
def test_client_get_succeed(self):
"""
Features: Get some objects successfully.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
buffer1 = client.create(object_key, len(value))
buffer1.wlatch()
buffer1.memory_copy(value)
buffer1.seal()
buffer1.unwlatch()
object_key2 = self.random_str(10)
value2 = bytes(self.random_str(100), encoding='utf8')
buffer2 = client.create(object_key2, len(value2))
buffer2.wlatch()
buffer2.memory_copy(value2)
buffer2.seal()
buffer2.unwlatch()
buffer_list = client.get([object_key, object_key2], 5)
self.assertEqual(value, buffer_list[0].immutable_data().tobytes())
self.assertEqual(value2, buffer_list[1].immutable_data().tobytes())
client2 = self.init_test_client()
buffer_list2 = client2.get([object_key, object_key2], 5)
self.assertEqual(value, buffer_list2[0].immutable_data().tobytes())
self.assertEqual(value2, buffer_list2[1].immutable_data().tobytes())
def test_client_get_failed(self):
"""
Features: Get an object that does not exist
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
buffer = client.create(object_key, len(value))
buffer.wlatch()
buffer.memory_copy(value)
buffer.seal()
buffer.unwlatch()
buffer_list = client.get([object_key, self.random_str(10)], 5)
self.assertEqual(value, buffer_list[0].immutable_data().tobytes())
self.assertIsNone(buffer_list[1])
@staticmethod
def init_client_timeout_with_interval(client) -> int:
"""
Features: Configure the client timeout interval, run and return the duration
"""
timeout_start = time.time()
try:
client.init()
except RuntimeError:
timeout_end = time.time()
else:
timeout_end = time.time()
return math.floor(timeout_end - timeout_start)
def test_client_init_timeout(self):
"""
Features: Test timeout case of client init
"""
host, port = "127.0.0.1", 18888
client = ObjectClient(host, port, 5 * 1000, "")
duration = self.init_client_timeout_with_interval(client)
self.assertLessEqual(duration, 6, "duration is {time} sec, want {interval}".format(time=duration, interval=5))
def test_client_increase_decrease_global_ref(self):
"""
Features: Test increase and decrease ref count
"""
client = self.init_test_client()
object_key = self.random_str(10)
client.g_increase_ref([object_key])
self.assertEqual(client.g_decrease_ref([object_key, "not_exist_key"]), [])
def test_client_put_get_big_bytes_succeed(self):
"""
Features: Put and get a big bytes object successfully.
"""
client = self.init_test_client()
value_size = 5 * 1024 * 1024
object_key1 = self.random_str(10)
value1 = bytes("x" * value_size, encoding='utf8')
buffer1 = client.create(object_key1, len(value1))
buffer1.wlatch()
buffer1.memory_copy(value1)
buffer1.seal()
buffer1.unwlatch()
buffer_list1 = client.get([object_key1], 5)
self.assertEqual(value1, buffer_list1[0].immutable_data())
def test_client_put_get_big_bytearray_succeed(self):
"""
Features: Put and get a big bytearray object successfully.
"""
client = self.init_test_client()
value_size = 5 * 1024 * 1024
object_key2 = self.random_str(10)
value2 = bytearray("x" * value_size, encoding='utf8')
buffer2 = client.create(object_key2, len(value2))
buffer2.wlatch()
buffer2.memory_copy(value2)
buffer2.seal()
buffer2.unwlatch()
buffer_list2 = client.get([object_key2], 5)
self.assertEqual(value2, buffer_list2[0].immutable_data())
def test_client_put_get_big_memoryview_succeed(self):
"""
Features: Put and get a big memoryview object successfully.
"""
client = self.init_test_client()
value_size = 5 * 1024 * 1024
object_key3 = self.random_str(10)
byte_value = bytes("x" * value_size, encoding='utf8')
value3 = memoryview(byte_value)
buffer3 = client.create(object_key3, len(value3))
buffer3.wlatch()
buffer3.memory_copy(value3)
buffer3.seal()
buffer3.unwlatch()
buffer_list3 = client.get([object_key3], 5)
self.assertEqual(value3, buffer_list3[0].immutable_data())
def test_concurrent_write_read_buffer(self):
"""
Features: Write data concurrently to the buffer.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
TestOcClientMethods.BUFFER = client.create(object_key, len(value))
thread_num = 5
threads = []
i = 0
while i in range(thread_num):
t = threading.Thread(target=self.write_buffer, args=(value,))
threads.append(t)
i += 1
for t in threads:
t.start()
for t in threads:
t.join()
buffer_val = self.read_buffer()
self.assertEqual(value, buffer_val)
def test_invalidate_buffer(self):
"""
Features: Invalidate the buffer.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytes(self.random_str(50), encoding='utf8')
buffer = client.create(object_key, len(value))
buffer.wlatch()
buffer.memory_copy(value)
buffer.publish()
buffer.unwlatch()
buffer.invalidate_buffer()
self.assertRaises(RuntimeError, client.get, [object_key], 5)
def test_buffer_immutable_data(self):
"""
Features: Get immutable memoryview from the buffer.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytearray(self.random_str(50), encoding='utf8')
buffer = client.create(object_key, len(value))
immutable_value = buffer.immutable_data()
self.assertTrue(immutable_value.readonly)
def test_buffer_mutable_data(self):
"""
Features: Get mutable memoryview from the buffer.
"""
client = self.init_test_client()
object_key = self.random_str(10)
value = bytearray(self.random_str(50), encoding='utf8')
buffer = client.create(object_key, len(value))
mutable_value = buffer.mutable_data()
self.assertFalse(mutable_value.readonly)
def test_nest_publish(self):
"""
Features: Test nested publish.
"""
client = self.init_test_client()
object1 = "key_1"
object2 = "key_2"
object3 = "key_3"
value = bytes(self.random_str(50), encoding='utf8')
client.g_increase_ref([object1, object2, object3])
self.assertEqual(client.query_global_ref_num(object1), 1)
self.assertEqual(client.query_global_ref_num(object2), 1)
self.assertEqual(client.query_global_ref_num(object3), 1)
param = {"consistency_type": ConsistencyType.PRAM}
client.put(object2, value, param)
client.put(object3, value, param)
buffer = client.create(object1, len(value))
buffer.wlatch()
buffer.memory_copy(value)
buffer.publish([object2, object3])
buffer.unwlatch()
buffer_list = client.get([object1, object2, object3], 0)
self.assertEqual(len(buffer_list), 3)
self.assertEqual(value, buffer_list[0].immutable_data())
self.assertEqual(value, buffer_list[1].immutable_data())
self.assertEqual(value, buffer_list[2].immutable_data())
client.g_decrease_ref([object2, object3])
self.assertEqual(client.query_global_ref_num(object1), 1)
self.assertEqual(client.query_global_ref_num(object2), 0)
self.assertEqual(client.query_global_ref_num(object3), 0)
self.assertEqual(len(buffer_list), 3)
buffer_list = client.get([object1, object2, object3], 0)
self.assertEqual(value, buffer_list[0].immutable_data())
self.assertEqual(value, buffer_list[1].immutable_data())
self.assertEqual(value, buffer_list[2].immutable_data())
client.g_decrease_ref([object1])
self.assertEqual(client.query_global_ref_num(object1), 0)
self.assertEqual(client.query_global_ref_num(object2), 0)
self.assertEqual(client.query_global_ref_num(object3), 0)
self.assertRaises(RuntimeError, client.get, [object1, object2, object3], 0)
def test_generate_object_id(self):
"""
Features: Test generate object key.
"""
client = self.init_test_client()
object_key = client.generate_object_id("abc")
self.assertEqual(client.g_increase_ref([object_key]), [])
value = bytes(self.random_str(50), encoding='utf8')
buffer1 = client.create(object_key, len(value))
buffer1.wlatch()
buffer1.memory_copy(value)
buffer1.seal()
buffer1.unwlatch()
self.assertEqual(client.g_decrease_ref([object_key]), [])