import time
from threading import Thread
from unittest.mock import patch
import pytest
import torch
import zmq
from tensordict import NonTensorStack, TensorDict
from transfer_queue import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.enum_utils import Role
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
)
# Three-sample fixture shared by MockStorage._handle_get_data and the client
# tests. It mixes the field kinds the client must round-trip:
#   - "log_probs": a list of equal-length 1-D tensors,
#   - "variable_length_sequences": a jagged nested tensor (rows of length 3, 4, 2),
#   - "prompt_text": a list of Python strings.
# test_put_and_get_data asserts the first two rows of these exact values, so
# keep the two in sync when editing.
TEST_DATA = TensorDict(
    {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
        "variable_length_sequences": torch.nested.as_nested_tensor(
            [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ],
            layout=torch.jagged,
        ),
        "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
    },
    batch_size=[3],
)
class MockController:
    """In-process stand-in for the TransferQueue controller used by these tests.

    Binds a ZMQ ROUTER socket on a random localhost port and answers client
    requests from a background daemon thread with canned responses. The KV
    request types are backed by a small in-memory registry, so keys created
    through KV_RETRIEVE_META can later be listed (KV_LIST) and mapped back
    from global indexes (KV_RETRIEVE_KEYS).
    """

    def __init__(self, controller_id="controller_0"):
        self.controller_id = controller_id

        # In-memory KV registry. Only the request thread reads or writes it,
        # but it must exist before that thread is started.
        # partition_id -> keys registered with create=True, in creation order.
        self._kv_partition_keys = {}
        # partition_id -> next global index to hand out in that partition.
        self._kv_index_map = {}

        self.context = zmq.Context()
        self.request_socket = self.context.socket(zmq.ROUTER)
        self.request_port = self._bind_to_random_port(self.request_socket)
        self.zmq_server_info = ZMQServerInfo(
            role=Role.CONTROLLER,
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "request_handle_socket": self.request_port,
            },
        )
        self.running = True
        self.request_thread = Thread(target=self._handle_requests, daemon=True)
        self.request_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a free localhost TCP port and return that port."""
        return socket.bind_to_random_port("tcp://127.0.0.1")

    def _handle_requests(self):
        """Serve requests until ``stop()`` clears ``self.running``.

        Runs on ``self.request_thread``. Any unexpected exception is printed
        and re-raised, which ends the thread.
        """
        poller = zmq.Poller()
        poller.register(self.request_socket, zmq.POLLIN)
        while self.running:
            try:
                # The 100 ms timeout bounds how long the loop takes to notice
                # that self.running was cleared.
                socks = dict(poller.poll(100))
                if self.request_socket not in socks:
                    continue
                # A ROUTER socket prefixes each message with the peer identity
                # frame; it must lead the reply so the reply is routed back.
                identity, *serialized_msg = self.request_socket.recv_multipart(copy=False)
                request_msg = ZMQMessage.deserialize(serialized_msg)
                response_type, response_body = self._build_response(request_msg)
                response_msg = ZMQMessage.create(
                    request_type=response_type,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body=response_body,
                )
                self.request_socket.send_multipart([identity, *response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                print(f"MockController ERROR: {e}")
                raise

    def _build_response(self, request_msg):
        """Return ``(response_type, response_body)`` for one deserialized request."""
        request_type = request_msg.request_type
        if request_type == ZMQRequestType.GET_META:
            return ZMQRequestType.GET_META_RESPONSE, self._mock_batch_meta(request_msg.body)
        if request_type == ZMQRequestType.CLEAR_META:
            return ZMQRequestType.CLEAR_META_RESPONSE, {"message": "clear meta ok"}
        if request_type == ZMQRequestType.CLEAR_PARTITION:
            return ZMQRequestType.CLEAR_PARTITION_RESPONSE, {"message": "clear partition ok"}
        if request_type == ZMQRequestType.GET_PARTITION_META:
            return ZMQRequestType.GET_PARTITION_META_RESPONSE, self._mock_batch_meta(request_msg.body)
        if request_type == ZMQRequestType.GET_CONSUMPTION:
            return ZMQRequestType.CONSUMPTION_RESPONSE, {
                "partition_id": request_msg.body.get("partition_id"),
                "global_index": torch.tensor([0, 1, 2]),
                "consumption_status": torch.tensor([1, 1, 1]),
            }
        if request_type == ZMQRequestType.GET_PRODUCTION:
            return ZMQRequestType.PRODUCTION_RESPONSE, {
                "partition_id": request_msg.body.get("partition_id"),
                "global_index": torch.tensor([0, 1, 2]),
                "production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]),
            }
        if request_type == ZMQRequestType.GET_LIST_PARTITIONS:
            return ZMQRequestType.LIST_PARTITIONS_RESPONSE, {
                "partition_ids": ["partition_0", "partition_1", "test_partition"],
            }
        if request_type == ZMQRequestType.SET_CUSTOM_META:
            return ZMQRequestType.SET_CUSTOM_META_RESPONSE, {"message": "success"}
        if request_type == ZMQRequestType.RESET_CONSUMPTION:
            return ZMQRequestType.RESET_CONSUMPTION_RESPONSE, {
                "success": True,
                "message": "Consumption reset successfully",
            }
        if request_type == ZMQRequestType.KV_RETRIEVE_META:
            return ZMQRequestType.KV_RETRIEVE_META_RESPONSE, self._mock_kv_retrieve_meta(request_msg.body)
        if request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
            return ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, self._mock_kv_retrieve_keys(request_msg.body)
        if request_type == ZMQRequestType.KV_LIST:
            return ZMQRequestType.KV_LIST_RESPONSE, self._mock_kv_list(request_msg.body)
        # Unknown request types still get a reply carrying an error body
        # (CLEAR_META_RESPONSE is reused as the envelope) instead of no reply.
        return ZMQRequestType.CLEAR_META_RESPONSE, {"error": f"Unknown request type: {request_type}"}

    def _mock_batch_meta(self, request_body):
        """Build a BatchMeta for ``batch_size`` samples of the requested fields.

        Global indexes are ``0..batch_size-1``, every sample is placed in
        partition ``"0"``, and each field's dtype/shape is left as ``None``.
        """
        batch_size = request_body.get("batch_size", 1)
        data_fields = request_body.get("data_fields", [])
        field_schema = {
            field_name: {"dtype": None, "shape": None, "is_nested": False, "is_non_tensor": False}
            for field_name in data_fields
        }
        metadata = BatchMeta(
            global_indexes=list(range(batch_size)),
            partition_ids=["0"] * batch_size,
            field_schema=field_schema,
        )
        return {"metadata": metadata}

    def _mock_kv_retrieve_meta(self, request_body):
        """Mock KV_RETRIEVE_META: allocate global indexes for ``keys``.

        Indexes are handed out sequentially from the partition's next free
        index. Only when ``create`` is true are the keys recorded and the
        partition's next free index advanced.

        NOTE(review): existing keys are not resolved to the index they were
        created with -- every call allocates from the next free index, so a
        ``create=False`` lookup of an existing key returns a different index.
        """
        import numpy as np

        keys = request_body.get("keys", [])
        create = request_body.get("create", False)
        partition_id = request_body.get("partition_id", "")

        start_index = self._get_next_kv_index(partition_id)
        global_indexes = list(range(start_index, start_index + len(keys)))
        field_schema = {
            "data": {"dtype": "torch.float32", "shape": [1, 10], "is_nested": False, "is_non_tensor": False}
        }
        production_status = np.ones(len(global_indexes), dtype=np.int8)
        metadata = BatchMeta(
            global_indexes=global_indexes,
            partition_ids=[partition_id] * len(global_indexes),
            field_schema=field_schema,
            production_status=production_status,
        )
        if create:
            self._kv_partition_keys.setdefault(partition_id, []).extend(keys)
            if global_indexes:
                self._update_kv_index(partition_id, global_indexes[-1] + 1)
        return {"metadata": metadata}

    def _mock_kv_list(self, request_body):
        """Mock KV_LIST: report every key registered in ``partition_id``.

        Each key maps to an empty info dict.
        """
        partition_id = request_body.get("partition_id", None)
        keys = self._kv_partition_keys.get(partition_id, [])
        return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"}

    def _mock_kv_retrieve_keys(self, request_body):
        """Mock KV_RETRIEVE_KEYS: map ``global_indexes`` back to registered keys.

        The registered keys of a partition occupy the contiguous index range
        that ends just before the partition's next free index. The mapping is
        rebuilt on every call, so keys created after an earlier lookup are
        visible. Indexes with no key map to ``None``.
        """
        global_indexes = request_body.get("global_indexes", [])
        partition_id = request_body.get("partition_id", "")

        partition_keys = self._kv_partition_keys.get(partition_id, [])
        start_idx = self._get_next_kv_index(partition_id) - len(partition_keys)
        index_map = {start_idx + offset: key for offset, key in enumerate(partition_keys)}
        return {"keys": [index_map.get(idx) for idx in global_indexes]}

    def _get_next_kv_index(self, partition_id):
        """Return the next free KV index in ``partition_id``, registering it on first use."""
        if partition_id not in self._kv_index_map:
            self._kv_index_map[partition_id] = 0
            # Never clobber keys that may already be recorded for the partition.
            self._kv_partition_keys.setdefault(partition_id, [])
        return self._kv_index_map[partition_id]

    def _update_kv_index(self, partition_id, next_index):
        """Set the next free KV index for ``partition_id``."""
        self._kv_index_map[partition_id] = next_index

    def stop(self):
        """Stop the request thread, then release the socket and the context."""
        self.running = False
        # Wait for the poll loop to exit (it wakes at least every 100 ms) before
        # closing: ZMQ sockets are not thread-safe, so the socket must not be
        # closed while the request thread may still be using it.
        self.request_thread.join(timeout=1.0)
        # linger=0 discards any unsent replies so context.term() cannot block.
        self.request_socket.close(linger=0)
        self.context.term()
class MockStorage:
    """In-process stand-in for a SimpleStorage unit used by these tests.

    Binds a ZMQ ROUTER socket on a random localhost port and answers data
    requests from a background daemon thread. PUT_DATA and CLEAR_DATA are
    acknowledged without storing anything; GET_DATA is served from the
    module-level ``TEST_DATA`` fixture.
    """

    def __init__(self, storage_id="storage_0"):
        self.storage_id = storage_id
        self.context = zmq.Context()
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_port = self._bind_to_random_port(self.data_socket)
        self.zmq_server_info = ZMQServerInfo(
            role=Role.STORAGE,
            id=storage_id,
            ip="127.0.0.1",
            ports={
                "put_get_socket": self.data_port,
            },
        )
        self.running = True
        self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
        self.data_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a free localhost TCP port and return that port."""
        return socket.bind_to_random_port("tcp://127.0.0.1")

    def _handle_data_requests(self):
        """Serve requests until ``stop()`` clears ``self.running``.

        While running, an exception raised for one request is printed and the
        loop moves on to the next request. After shutdown it is printed and
        re-raised, which ends the thread.
        """
        poller = zmq.Poller()
        poller.register(self.data_socket, zmq.POLLIN)
        while self.running:
            try:
                # The 100 ms timeout bounds how long the loop takes to notice
                # that self.running was cleared.
                socks = dict(poller.poll(100))
                if self.data_socket not in socks:
                    continue
                # A ROUTER socket prefixes each message with the peer identity
                # frame; it must lead the reply so the reply is routed back.
                identity, *serialized_msg = self.data_socket.recv_multipart(copy=False)
                msg = ZMQMessage.deserialize(serialized_msg)
                response_type, response_body = self._build_response(msg)
                response_msg = ZMQMessage.create(
                    request_type=response_type,
                    sender_id=self.storage_id,
                    receiver_id=msg.sender_id,
                    body=response_body,
                )
                self.data_socket.send_multipart([identity, *response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                if self.running:
                    print(f"MockStorage running exception: {e}")
                else:
                    print(f"MockStorage ERROR: {e}")
                    raise

    def _build_response(self, msg):
        """Return ``(response_type, response_body)`` for one deserialized request."""
        if msg.request_type == ZMQRequestType.PUT_DATA:
            return ZMQRequestType.PUT_DATA_RESPONSE, {"message": "Data stored successfully"}
        if msg.request_type == ZMQRequestType.GET_DATA:
            return ZMQRequestType.GET_DATA_RESPONSE, self._handle_get_data(msg.body)
        if msg.request_type == ZMQRequestType.CLEAR_DATA:
            return ZMQRequestType.CLEAR_DATA_RESPONSE, {"message": "Data cleared successfully"}
        # Without this branch an unknown type left the response unbound and
        # raised inside the serve loop; reply with an error body instead.
        return ZMQRequestType.CLEAR_DATA_RESPONSE, {"error": f"Unknown request type: {msg.request_type}"}

    def _handle_get_data(self, request_body):
        """Gather ``fields`` at the requested ``global_indexes`` from TEST_DATA.

        A field whose gathered rows are all tensors is packed into a jagged
        nested tensor; any other field becomes a NonTensorStack. A field with
        no requested rows is left out of the returned TensorDict.
        """
        global_indexes = request_body.get("global_indexes", [])
        fields = request_body.get("fields", [])
        result: dict[str, object] = {}
        for field in fields:
            gathered_items = [TEST_DATA[field][i] for i in global_indexes]
            if not gathered_items:
                continue
            if all(isinstance(x, torch.Tensor) for x in gathered_items):
                result[field] = torch.nested.as_nested_tensor(gathered_items, layout=torch.jagged)
            else:
                result[field] = NonTensorStack(*gathered_items)
        return {"data": TensorDict(result)}

    def stop(self):
        """Stop the data thread, then release the socket and the context."""
        self.running = False
        # Join before closing: ZMQ sockets are not thread-safe, so the socket
        # must not be closed while the data thread may still be polling it.
        self.data_thread.join(timeout=1.0)
        # linger=0 discards any unsent replies so context.term() cannot block.
        self.data_socket.close(linger=0)
        self.context.term()
@pytest.fixture
def mock_controller():
    """Provide a running MockController and shut it down once the test finishes."""
    controller_under_test = MockController()
    yield controller_under_test
    controller_under_test.stop()
@pytest.fixture
def mock_storage():
    """Provide a running MockStorage and shut it down once the test finishes."""
    storage_under_test = MockStorage()
    yield storage_under_test
    storage_under_test.stop()
@pytest.fixture
def client_setup(mock_controller, mock_storage):
    """Build a TransferQueueClient wired to the mock controller and storage.

    The storage manager is initialised with its controller handshake patched
    out. Its put/get/clear coroutines are then replaced with stubs: put and
    clear do nothing, and get always returns ``TEST_DATA``.

    Yields:
        ``(client, mock_controller, mock_storage)``
    """
    client = TransferQueueClient(
        client_id="client_0",
        controller_info=mock_controller.zmq_server_info,
    )

    handshake_target = (
        "transfer_queue.storage.managers.simple_storage_manager."
        "AsyncSimpleStorageManager._connect_to_controller"
    )
    with patch(handshake_target):
        storage_config = {
            "controller_info": mock_controller.zmq_server_info,
            "zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info},
        }
        client.initialize_storage_manager(manager_type="SimpleStorage", config=storage_config)

    async def stub_put_data(data, metadata, data_parser=None):
        pass

    async def stub_get_data(metadata):
        return TEST_DATA

    async def stub_clear_data(metadata):
        pass

    manager = client.storage_manager
    manager.put_data = stub_put_data
    manager.get_data = stub_get_data
    manager.clear_data = stub_clear_data

    yield client, mock_controller, mock_storage
def test_client_initialization(client_setup):
    """After setup the client has an id and points at the mock controller."""
    client, controller, _storage = client_setup
    assert client.client_id is not None

    controller_info = client._controller
    assert controller_info is not None
    assert controller_info.id == controller.controller_id
def test_put_and_get_data(client_setup):
    """Put TEST_DATA, fetch two rows back, and compare them field by field."""
    client, _, _ = client_setup
    client.put(data=TEST_DATA, partition_id="0")

    requested_fields = ["log_probs", "variable_length_sequences", "prompt_text"]
    metadata = client.get_meta(data_fields=list(requested_fields), batch_size=2, partition_id="0")
    result = client.get_data(metadata)

    for field in requested_fields:
        assert field in result

    expected_tensor_rows = {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])],
        "variable_length_sequences": [
            torch.tensor([-0.5, -1.2, -0.8]),
            torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        ],
    }
    for field, rows in expected_tensor_rows.items():
        for row_idx, expected in enumerate(rows):
            torch.testing.assert_close(result[field][row_idx], expected)

    expected_text = ["Hello world!", "This is a longer sentence for testing"]
    for row_idx, text in enumerate(expected_text):
        assert result["prompt_text"][row_idx] == text
def test_get_meta(client_setup):
    """get_meta returns metadata exposing indexes, field names and size for the batch."""
    client = client_setup[0]
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, partition_id="0")

    for attribute in ("global_indexes", "field_names", "size"):
        assert hasattr(metadata, attribute)
    assert len(metadata.global_indexes) == 10
def test_single_controller_multiple_storages():
    """A client configured with one controller and three storage units can put data."""
    controller = MockController("controller_0")
    storages = [MockStorage(f"storage_{i}") for i in range(3)]
    try:
        client = TransferQueueClient(
            client_id="client_test_single_controller",
            controller_info=controller.zmq_server_info,
        )
        with patch(
            "transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller"
        ):
            storage_config = {
                "controller_info": controller.zmq_server_info,
                "zmq_info": {unit.storage_id: unit.zmq_server_info for unit in storages},
            }
            client.initialize_storage_manager(manager_type="SimpleStorage", config=storage_config)

        async def stub_put_data(data, metadata, data_parser=None):
            pass

        async def stub_get_data(metadata):
            return TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

        async def stub_clear_data(metadata):
            pass

        manager = client.storage_manager
        manager.put_data = stub_put_data
        manager.get_data = stub_get_data
        manager.clear_data = stub_clear_data

        assert client._controller is not None
        assert client._controller.id == controller.controller_id

        payload = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
        client.put(data=payload, partition_id="0")
    finally:
        controller.stop()
        for unit in storages:
            unit.stop()
def test_put_without_required_params(client_setup):
    """put() called with data only (no partition_id) raises ValueError."""
    client = client_setup[0]
    payload = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
    with pytest.raises(ValueError):
        client.put(data=payload)
def test_check_consumption_status(client_setup):
    """The mock marks every sample consumed, so the check reports True."""
    client = client_setup[0]
    consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    assert consumed is True
def test_check_production_status(client_setup):
    """The mock marks every field produced, so the check reports True."""
    client = client_setup[0]
    produced = client.check_production_status(data_fields=["prompt_ids", "attention_mask"], partition_id="train_0")
    assert produced is True
def test_get_consumption_status(client_setup):
    """get_consumption_status yields the mock's index tensor and an all-ones int8 status."""
    client = client_setup[0]
    indexes, status = client.get_consumption_status(task_name="generate_sequences", partition_id="train_0")

    assert indexes is not None
    assert status is not None
    assert torch.equal(indexes, torch.tensor([0, 1, 2], dtype=torch.long))
    assert torch.equal(status, torch.tensor([1, 1, 1], dtype=torch.int8))
    print("✓ get_consumption_status returns correct global_index and consumption_status")
def test_get_production_status(client_setup):
    """get_production_status yields the mock's index tensor and a 2x3 all-ones int8 status."""
    client = client_setup[0]
    indexes, status = client.get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )

    assert indexes is not None
    assert status is not None
    assert torch.equal(indexes, torch.tensor([0, 1, 2], dtype=torch.long))
    assert torch.equal(status, torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8))
    print("✓ get_production_status returns correct global_index and production_status")
def test_get_partition_list(client_setup):
    """get_partition_list returns a non-empty list with the mock's partition ids."""
    client = client_setup[0]
    partitions = client.get_partition_list()

    assert isinstance(partitions, list)
    assert len(partitions) > 0
    for expected_id in ("partition_0", "partition_1", "test_partition"):
        assert expected_id in partitions
def test_reset_consumption(client_setup):
    """Sync reset_consumption for one task reports success."""
    client = client_setup[0]
    reset_ok = client.reset_consumption(partition_id="train_0", task_name="generate_sequences")
    assert reset_ok is True
    print("✓ reset_consumption with task_name returns True")
def test_reset_consumption_all_tasks(client_setup):
    """Sync reset_consumption with no task_name (every task) reports success."""
    client = client_setup[0]
    reset_ok = client.reset_consumption(partition_id="train_0")
    assert reset_ok is True
    print("✓ reset_consumption without task_name (all tasks) returns True")
@pytest.mark.asyncio
async def test_async_reset_consumption(client_setup):
    """Async reset_consumption for one task reports success."""
    client = client_setup[0]
    reset_ok = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences")
    assert reset_ok is True
    print("✓ async_reset_consumption with task_name returns True")
@pytest.mark.asyncio
async def test_async_reset_consumption_all_tasks(client_setup):
    """Async reset_consumption with no task_name (every task) reports success."""
    client = client_setup[0]
    reset_ok = await client.async_reset_consumption(partition_id="train_0")
    assert reset_ok is True
    print("✓ async_reset_consumption without task_name (all tasks) returns True")
@pytest.mark.asyncio
async def test_async_check_consumption_status(client_setup):
    """Async consumption check reports True against the all-consumed mock."""
    client = client_setup[0]
    consumed = await client.async_check_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )
    assert consumed is True
@pytest.mark.asyncio
async def test_async_check_production_status(client_setup):
    """Async production check reports True against the all-produced mock."""
    client = client_setup[0]
    produced = await client.async_check_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )
    assert produced is True
@pytest.mark.asyncio
async def test_async_get_consumption_status(client_setup):
    """async_get_consumption_status yields the mock's index tensor and all-ones int8 status."""
    client = client_setup[0]
    indexes, status = await client.async_get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )

    assert indexes is not None
    assert status is not None
    assert torch.equal(indexes, torch.tensor([0, 1, 2], dtype=torch.long))
    assert torch.equal(status, torch.tensor([1, 1, 1], dtype=torch.int8))
    print("✓ async_get_consumption_status returns correct global_index and consumption_status")
@pytest.mark.asyncio
async def test_async_get_production_status(client_setup):
    """async_get_production_status yields the mock's index tensor and 2x3 all-ones int8 status."""
    client = client_setup[0]
    indexes, status = await client.async_get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )

    assert indexes is not None
    assert status is not None
    assert torch.equal(indexes, torch.tensor([0, 1, 2], dtype=torch.long))
    assert torch.equal(status, torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8))
    print("✓ async_get_production_status returns correct global_index and production_status")
@pytest.mark.asyncio
async def test_async_get_partition_list(client_setup):
    """async_get_partition_list returns a non-empty list with the mock's partition ids."""
    client = client_setup[0]
    partitions = await client.async_get_partition_list()

    assert isinstance(partitions, list)
    assert len(partitions) > 0
    for expected_id in ("partition_0", "partition_1", "test_partition"):
        assert expected_id in partitions
@pytest.mark.asyncio
async def test_async_clear_partition(client_setup):
    """async_clear_partition completes without raising."""
    client = client_setup[0]
    await client.async_clear_partition(partition_id="test_partition")
@pytest.mark.asyncio
async def test_async_clear_samples(client_setup):
    """async_clear_samples accepts metadata obtained from async_get_meta without raising."""
    client = client_setup[0]
    to_clear = await client.async_get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
    await client.async_clear_samples(metadata=to_clear)
def test_clear_partition(client_setup):
    """Sync clear_partition completes without raising."""
    client = client_setup[0]
    client.clear_partition(partition_id="test_partition")
def test_clear_samples(client_setup):
    """Sync clear_samples accepts metadata obtained from get_meta without raising."""
    client = client_setup[0]
    to_clear = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
    client.clear_samples(metadata=to_clear)
@pytest.mark.asyncio
async def test_async_clear_samples_with_empty_metadata(client_setup):
    """async_clear_samples tolerates a BatchMeta that holds no samples."""
    client = client_setup[0]
    empty_meta = BatchMeta(global_indexes=[], partition_ids=[], field_schema={})
    await client.async_clear_samples(metadata=empty_meta)
@pytest.mark.asyncio
async def test_sync_methods_work_in_async_event_loop(client_setup):
    """Call every synchronous client method from inside a running event loop.

    The sync wrappers must work when the caller is itself a coroutine,
    i.e. they must not fail with "asyncio.run() cannot be called from a
    running loop".
    """
    client = client_setup[0]

    # Data path: put -> get_meta -> get_data.
    payload = TensorDict({"tokens": torch.randint(0, 100, (3, 64))}, batch_size=3)
    put_meta = client.put(data=payload, partition_id="0")
    assert put_meta is not None

    read_meta = client.get_meta(
        data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
    )
    assert read_meta is not None
    assert len(read_meta.global_indexes) == 2

    fetched = client.get_data(read_meta)
    assert fetched is not None
    assert "log_probs" in fetched
    assert "prompt_text" in fetched

    # Consumption status queries.
    consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    assert isinstance(consumed, bool)
    consumed_idx, consumed_status = client.get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )
    assert consumed_idx is not None
    assert consumed_status is not None

    # Production status queries.
    produced = client.check_production_status(data_fields=["log_probs", "prompt_text"], partition_id="train_0")
    assert isinstance(produced, bool)
    produced_idx, produced_status = client.get_production_status(
        data_fields=["log_probs", "prompt_text"], partition_id="train_0"
    )
    assert produced_idx is not None
    assert produced_status is not None

    # Partition listing and clearing.
    partitions = client.get_partition_list()
    assert isinstance(partitions, list)
    assert len(partitions) > 0
    client.clear_partition(partition_id="test_partition")

    to_clear = client.get_meta(data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0")
    client.clear_samples(metadata=to_clear)
    print("✓ All sync methods work correctly when called from within asyncio event loop")
@pytest.mark.asyncio
async def test_sync_and_async_methods_mixed_usage(client_setup):
    """Interleave sync and async client calls inside one coroutine.

    Metadata produced by a sync call must be usable by an async call (and
    vice versa) without the two styles interfering.
    """
    client = client_setup[0]
    payload = TensorDict({"tokens": torch.randint(0, 100, (2, 32))}, batch_size=2)

    put_result = client.put(data=payload, partition_id="0")
    assert put_result is not None

    meta_from_async = await client.async_get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
    assert meta_from_async is not None

    meta_from_sync = client.get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
    assert meta_from_sync is not None

    data_from_async = await client.async_get_data(meta_from_sync)
    assert data_from_async is not None
    print("✓ Mixed async and sync method calls work correctly")
class TestClientCustomMetaInterface:
    """Tests for the client's set_custom_meta / async_set_custom_meta methods."""

    @staticmethod
    def _per_sample_custom_meta():
        """Return a fresh per-sample custom-meta list for a two-sample batch."""
        return [
            {"input_ids": {"token_count": 100}},
            {"input_ids": {"token_count": 120}},
        ]

    def test_set_custom_meta_sync(self, client_setup):
        """Sync set_custom_meta accepts metadata carrying per-sample custom fields."""
        client = client_setup[0]
        metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        metadata.update_custom_meta(self._per_sample_custom_meta())
        client.set_custom_meta(metadata)
        print("✓ set_custom_meta sync method works")

    @pytest.mark.asyncio
    async def test_set_custom_meta_async(self, client_setup):
        """async_set_custom_meta accepts metadata carrying per-sample custom fields."""
        client = client_setup[0]
        metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
        metadata.update_custom_meta(self._per_sample_custom_meta())
        await client.async_set_custom_meta(metadata)
        print("✓ async_set_custom_meta async method works")
class TestClientKVInterface:
    """Tests for the client's KV interface (retrieve_meta / list / retrieve_keys).

    Requests are answered by MockController. Its per-partition KV registry
    starts empty in every test, because ``mock_controller`` is a
    function-scoped fixture.
    """

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_single(self, client_setup):
        """Test async_kv_retrieve_meta with single key."""
        client, _, _ = client_setup
        metadata = await client.async_kv_retrieve_meta(
            keys="test_key_1",
            partition_id="test_partition",
            create=True,
        )
        assert metadata is not None
        assert hasattr(metadata, "global_indexes")
        assert hasattr(metadata, "size")
        assert metadata.size == 1

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_multiple(self, client_setup):
        """Test async_kv_retrieve_meta with multiple keys."""
        client, _, _ = client_setup
        keys = ["key_a", "key_b", "key_c"]
        metadata = await client.async_kv_retrieve_meta(
            keys=keys,
            partition_id="test_partition",
            create=True,
        )
        assert metadata is not None
        assert hasattr(metadata, "global_indexes")
        assert hasattr(metadata, "size")
        assert metadata.size == 3

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_create_false(self, client_setup):
        """Test async_kv_retrieve_meta with create=False (retrieve existing keys)."""
        client, _, _ = client_setup
        await client.async_kv_retrieve_meta(
            keys="existing_key",
            partition_id="existing_partition",
            create=True,
        )
        metadata = await client.async_kv_retrieve_meta(
            keys="existing_key",
            partition_id="existing_partition",
            create=False,
        )
        assert metadata is not None
        assert metadata.size == 1

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_meta_invalid_keys_type(self, client_setup):
        """Test async_kv_retrieve_meta raises error with invalid keys type."""
        client, _, _ = client_setup
        with pytest.raises(TypeError):
            await client.async_kv_retrieve_meta(
                keys=123,
                partition_id="test_partition",
                create=True,
            )

    @pytest.mark.asyncio
    async def test_async_kv_list_with_keys(self, client_setup):
        """Test async_kv_list returns keys after they are registered."""
        client, _, _ = client_setup
        await client.async_kv_retrieve_meta(
            keys=["key_1", "key_2"],
            partition_id="kv_partition",
            create=True,
        )
        partition_info = await client.async_kv_list(partition_id="kv_partition")
        listed = partition_info["kv_partition"]
        # Exactly the two keys registered above, each with empty per-key info.
        assert len(listed) == 2
        assert "key_1" in listed
        assert "key_2" in listed
        assert list(listed.values()) == [{}, {}]

    @pytest.mark.asyncio
    async def test_async_kv_list_multiple_partitions(self, client_setup):
        """Test async_kv_list with multiple partitions."""
        client, _, _ = client_setup
        await client.async_kv_retrieve_meta(
            keys="partition_a_key",
            partition_id="partition_a",
            create=True,
        )
        await client.async_kv_retrieve_meta(
            keys="partition_b_key",
            partition_id="partition_b",
            create=True,
        )
        partition_a = await client.async_kv_list(partition_id="partition_a")
        partition_b = await client.async_kv_list(partition_id="partition_b")
        assert "partition_a" in partition_a
        assert "partition_b" in partition_b
        assert "partition_a" not in partition_b
        assert "partition_b" not in partition_a
        assert "partition_a_key" in partition_a["partition_a"]
        assert "partition_b_key" not in partition_a["partition_a"]
        assert "partition_b_key" in partition_b["partition_b"]
        assert "partition_a_key" not in partition_b["partition_b"]
        assert list(partition_a["partition_a"].values()) == [{}]
        assert list(partition_b["partition_b"].values()) == [{}]

    def test_kv_retrieve_meta_type_validation(self, client_setup):
        """kv_retrieve_meta rejects a keys list containing an int with TypeError.

        The synchronous method (which this test is named for) is checked
        first. The async coroutine is then also checked, driven to completion
        on a fresh event loop.
        """
        import asyncio

        client, _, _ = client_setup
        invalid_keys = ["valid_key", 123]

        with pytest.raises(TypeError):
            client.kv_retrieve_meta(
                keys=invalid_keys,
                partition_id="test_partition",
                create=True,
            )

        async def call_async():
            await client.async_kv_retrieve_meta(
                keys=invalid_keys,
                partition_id="test_partition",
                create=True,
            )

        with pytest.raises(TypeError):
            asyncio.run(call_async())

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_single(self, client_setup):
        """Test async_kv_retrieve_keys with single global_index."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"
        await client.async_kv_retrieve_meta(
            keys=["test_key"],
            partition_id=partition_id,
            create=True,
        )
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_id,
        )
        assert keys == ["test_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_multiple(self, client_setup):
        """Test async_kv_retrieve_keys with multiple global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"
        keys_to_create = ["key_a", "key_b", "key_c"]
        await client.async_kv_retrieve_meta(
            keys=keys_to_create,
            partition_id=partition_id,
            create=True,
        )
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0, 1, 2],
            partition_id=partition_id,
        )
        assert keys == ["key_a", "key_b", "key_c"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_partial(self, client_setup):
        """Test async_kv_retrieve_keys with subset of global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"
        await client.async_kv_retrieve_meta(
            keys=["first_key", "second_key", "third_key"],
            partition_id=partition_id,
            create=True,
        )
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[0, 2],
            partition_id=partition_id,
        )
        assert keys == ["first_key", "third_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_single_int(self, client_setup):
        """Test async_kv_retrieve_keys accepts a single int."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"
        await client.async_kv_retrieve_meta(
            keys=["single_key"],
            partition_id=partition_id,
            create=True,
        )
        keys = await client.async_kv_retrieve_keys(
            global_indexes=0,
            partition_id=partition_id,
        )
        assert keys == ["single_key"]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_invalid_type(self, client_setup):
        """Test async_kv_retrieve_keys raises error with invalid global_indexes type."""
        client, _, _ = client_setup
        with pytest.raises(TypeError):
            await client.async_kv_retrieve_keys(
                global_indexes=["not_an_int"],
                partition_id="test_partition",
            )

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_empty_list(self, client_setup):
        """Test async_kv_retrieve_keys raises error with empty list."""
        client, _, _ = client_setup
        with pytest.raises(ValueError):
            await client.async_kv_retrieve_keys(
                global_indexes=[],
                partition_id="test_partition",
            )

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_non_existent(self, client_setup):
        """Test async_kv_retrieve_keys returns None for non-existent global_indexes."""
        client, _, _ = client_setup
        partition_id = "test_partition_idx"
        await client.async_kv_retrieve_meta(
            keys=["existing_key"],
            partition_id=partition_id,
            create=True,
        )
        keys = await client.async_kv_retrieve_keys(
            global_indexes=[99],
            partition_id=partition_id,
        )
        assert keys == [None]

    @pytest.mark.asyncio
    async def test_async_kv_retrieve_keys_multiple_partitions(self, client_setup):
        """Test async_kv_retrieve_keys returns keys from the correct partition."""
        client, _, _ = client_setup
        partition_1 = "partition_1"
        partition_2 = "partition_2"
        await client.async_kv_retrieve_meta(
            keys=["key_1"],
            partition_id=partition_1,
            create=True,
        )
        await client.async_kv_retrieve_meta(
            keys=["key_2"],
            partition_id=partition_2,
            create=True,
        )
        keys_1 = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_1,
        )
        keys_2 = await client.async_kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_2,
        )
        assert keys_1 == ["key_1"]
        assert keys_2 == ["key_2"]

    def test_kv_retrieve_keys_sync(self, client_setup):
        """Test synchronous kv_retrieve_keys."""
        client, _, _ = client_setup
        partition_id = "test_partition_sync"
        client.kv_retrieve_meta(
            keys=["sync_key"],
            partition_id=partition_id,
            create=True,
        )
        keys = client.kv_retrieve_keys(
            global_indexes=[0],
            partition_id=partition_id,
        )
        assert keys == ["sync_key"]

    def test_kv_retrieve_keys_type_validation(self, client_setup):
        """Test synchronous kv_retrieve_keys type validation."""
        client, _, _ = client_setup
        with pytest.raises(TypeError):
            client.kv_retrieve_keys(
                global_indexes=[0, "invalid"],
                partition_id="test_partition",
            )