import time
import pytest
import ray
import tensordict
import torch
import zmq
from transfer_queue.storage.simple_storage import SimpleStorageUnit
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
class MockStorageClient:
    """Mock client for testing storage unit operations.

    Wraps a ZMQ DEALER socket connected to a storage unit's put/get address
    and offers one blocking request/response method per request type. The
    client can also be used as a context manager (``with MockStorageClient(addr)
    as client:``) so that the socket and context are released on any exit path.
    """

    def __init__(self, storage_put_get_address):
        """Create a DEALER socket and connect it to ``storage_put_get_address``.

        Args:
            storage_put_get_address: ZMQ endpoint the socket connects to.
        """
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        # Bound every recv to 5 s so an unresponsive server fails the test
        # instead of blocking the whole run.
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)
        self.socket.connect(storage_put_get_address)

    def __enter__(self):
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the client on leaving the ``with`` block; never swallow errors."""
        self.close()
        return False

    def _request(self, request_type, client_id, body):
        """Send one request and block for its reply.

        Args:
            request_type: ``ZMQRequestType`` member identifying the operation.
            client_id: Value embedded in the sender id (``mock_client_<client_id>``).
            body: Request payload passed to ``ZMQMessage.create``.

        Returns:
            The reply deserialized with ``ZMQMessage.deserialize``.
        """
        msg = ZMQMessage.create(
            request_type=request_type,
            sender_id=f"mock_client_{client_id}",
            body=body,
        )
        self.socket.send_multipart(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))

    def send_put(self, client_id, global_indexes, field_data, data_parser=None):
        """Send a PUT_DATA request and return the server's reply.

        Args:
            client_id: Value embedded in the sender id.
            global_indexes: Indexes the data is stored under.
            field_data: Mapping of field name to the values to store.
            data_parser: Optional parser forwarded to the server; the
                ``data_parser`` key is only added to the body when this is
                not ``None``.
        """
        body = {"global_indexes": global_indexes, "data": field_data}
        if data_parser is not None:
            body["data_parser"] = data_parser
        return self._request(ZMQRequestType.PUT_DATA, client_id, body)

    def send_get(self, client_id, global_indexes, fields):
        """Send a GET_DATA request for ``fields`` at ``global_indexes`` and return the reply."""
        body = {"global_indexes": global_indexes, "fields": fields}
        return self._request(ZMQRequestType.GET_DATA, client_id, body)

    def send_clear(self, client_id, global_indexes):
        """Send a CLEAR_DATA request for ``global_indexes`` and return the reply."""
        body = {"global_indexes": global_indexes}
        return self._request(ZMQRequestType.CLEAR_DATA, client_id, body)

    def close(self):
        """Close the socket and terminate the context.

        The socket is closed with ``linger=0`` so that any request that was
        never delivered is dropped; with the default linger, ``context.term()``
        may block indefinitely waiting for it to be sent.
        """
        self.socket.close(linger=0)
        self.context.term()
@pytest.fixture(scope="session")
def ray_setup():
    """Start Ray once for the whole test session and shut it down afterwards.

    ``ignore_reinit_error=True`` is passed so that calling ``ray.init`` when
    Ray is already initialized is not treated as an error.
    """
    ray.init(ignore_reinit_error=True)
    # Nothing is handed to the tests; the fixture exists for its setup and
    # teardown side effects only.
    yield None
    ray.shutdown()
@pytest.fixture
def storage_setup(ray_setup):
    """Create a fresh ``SimpleStorageUnit`` actor for one test.

    Yields:
        A ``(storage_actor, put_get_address)`` tuple, where ``put_get_address``
        is the address obtained from the actor's ZMQ server info for
        ``"put_get_socket"``. The actor is killed once the test finishes.
    """
    unit_capacity = 10000
    tensordict.set_list_to_stack(True).set()

    actor_factory = SimpleStorageUnit.options(max_concurrency=50, num_cpus=1)
    unit_actor = actor_factory.remote(storage_unit_size=unit_capacity)

    server_info = ray.get(unit_actor.get_zmq_server_info.remote())
    address = server_info.to_addr("put_get_socket")

    # NOTE(review): presumably gives the actor's socket time to become ready
    # before the first client connects — confirm whether this is still needed.
    time.sleep(1)

    yield unit_actor, address

    ray.kill(unit_actor)
def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client.

    Writes two tensor fields for indexes 0-2, reads indexes 0 and 1 back, and
    checks the returned tensors equal the ones that were written.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)
    # Release the socket/context even if an assertion or a recv timeout aborts
    # the test; otherwise every failure leaks a ZMQ context.
    try:
        global_indexes = [0, 1, 2]
        field_data = {
            "log_probs": [
                torch.tensor([1.0, 2.0, 3.0]),
                torch.tensor([4.0, 5.0, 6.0]),
                torch.tensor([7.0, 8.0, 9.0]),
            ],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
        }

        response = client.send_put(0, global_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        response = client.send_get(0, [0, 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert "log_probs" in retrieved_data
        assert "rewards" in retrieved_data
        assert len(retrieved_data["log_probs"]) == 2
        assert len(retrieved_data["rewards"]) == 2

        torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
        torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
        torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([10.0]))
        torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0]))
    finally:
        client.close()
def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients.

    Three clients each write three samples to their own index range
    (``i * 10`` to ``i * 10 + 2``). A fourth client then writes index 0 again,
    and the test expects later reads of index 0 to return the newer values
    while all other indexes keep the data originally written.
    """
    _, put_get_address = storage_setup
    num_clients = 3
    clients = []
    overlapping_client = None
    # Track every client as soon as it exists so the finally block can close
    # all of them even when the test fails half-way through.
    try:
        for _ in range(num_clients):
            clients.append(MockStorageClient(put_get_address))

        for i, client in enumerate(clients):
            global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
            field_data = {
                "log_probs": [
                    torch.tensor([i, i + 1, i + 2]),
                    torch.tensor([i + 3, i + 4, i + 5]),
                    torch.tensor([i + 6, i + 7, i + 8]),
                ],
                "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])],
            }
            response = client.send_put(i, global_indexes, field_data)
            assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # A separate client writes index 0 again, which client 0 already wrote.
        overlapping_client = MockStorageClient(put_get_address)
        overlap_global_indexes = [0]
        overlap_field_data = {"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]}
        response = overlapping_client.send_put(99, overlap_global_indexes, overlap_field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        for i, client in enumerate(clients):
            response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
            assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

            retrieved_data = response.body["data"]
            assert len(retrieved_data["log_probs"]) == 2
            assert len(retrieved_data["rewards"]) == 2

            if i == 0:
                # Index 0 must hold the overlapping write; index 1 is untouched.
                torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
                torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
                torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
                torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
            else:
                torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
                torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
                torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
                torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))
    finally:
        for client in clients:
            client.close()
        if overlapping_client is not None:
            overlapping_client.close()
def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume.

    Issues 10 PUT requests followed by 10 GET requests, each covering a batch
    of 16 indexes with two 100-element tensor fields, and asserts that the
    average round-trip latency of each kind stays below 1500 ms.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    num_puts = 10
    num_gets = 10
    batch_size = 16

    # Close the client on every exit path so a failed assertion or recv
    # timeout does not leak the ZMQ socket and context.
    try:
        put_latencies = []
        for i in range(num_puts):
            global_indexes = list(range(i * batch_size, (i + 1) * batch_size))
            field_data = {
                "log_probs": [torch.randn(100) for _ in range(batch_size)],
                "rewards": [torch.randn(100) for _ in range(batch_size)],
            }

            # Start the clock only after the payload is built, so the figure
            # reflects the request round trip rather than tensor generation.
            # perf_counter is monotonic, unlike wall-clock time.time().
            start = time.perf_counter()
            response = client.send_put(0, global_indexes, field_data)
            put_latencies.append(time.perf_counter() - start)
            assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        get_latencies = []
        for i in range(num_gets):
            global_indexes = list(range(i * batch_size, (i + 1) * batch_size))

            start = time.perf_counter()
            response = client.send_get(0, global_indexes, ["log_probs", "rewards"])
            get_latencies.append(time.perf_counter() - start)
            assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        # Latencies are collected in seconds; report and compare in milliseconds.
        avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000
        avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000

        assert avg_put_latency < 1500, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
        assert avg_get_latency < 1500, f"Avg GET latency {avg_get_latency}ms exceeds threshold"
    finally:
        client.close()
def test_put_get_nested_tensor(storage_setup):
    """Test put and get operations with nested tensors.

    Each field holds tensors of different lengths (3, 4 and 2 elements). The
    test reads indexes 0 and 2 back and checks each tensor keeps its own
    length and values.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)
    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        global_indexes = [0, 1, 2]
        field_data = {
            "variable_length_sequences": [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ],
            "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])],
        }

        response = client.send_put(0, global_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        response = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert "variable_length_sequences" in retrieved_data
        assert "attention_mask" in retrieved_data
        assert len(retrieved_data["variable_length_sequences"]) == 2
        assert len(retrieved_data["attention_mask"]) == 2

        torch.testing.assert_close(retrieved_data["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
        torch.testing.assert_close(retrieved_data["variable_length_sequences"][1], torch.tensor([-1.1, -0.7]))
        torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1]))
        torch.testing.assert_close(retrieved_data["attention_mask"][1], torch.tensor([1, 1]))
    finally:
        client.close()
def test_put_get_non_tensor_data(storage_setup):
    """Test put and get operations with non-tensor data (strings).

    Stores two string fields for indexes 0-2 and checks that every value
    comes back as a ``str`` equal to what was written.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)
    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        global_indexes = [0, 1, 2]
        field_data = {
            "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
            "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"],
        }

        response = client.send_put(0, global_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert "prompt_text" in retrieved_data
        assert "response_text" in retrieved_data

        assert isinstance(retrieved_data["prompt_text"][0], str)
        assert isinstance(retrieved_data["response_text"][0], str)

        assert retrieved_data["prompt_text"][0] == "Hello world!"
        assert retrieved_data["prompt_text"][1] == "This is a longer sentence for testing"
        assert retrieved_data["prompt_text"][2] == "Test case"
        assert retrieved_data["response_text"][0] == "Hi there!"
        assert retrieved_data["response_text"][1] == "This is the response to the longer sentence"
        assert retrieved_data["response_text"][2] == "Test response"
    finally:
        client.close()
def test_put_get_single_item(storage_setup):
    """Test put and get operations for a single item.

    Stores one string and one tensor under index 0 and checks both come back
    as one-element collections with the original values.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)
    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        field_data = {
            "prompt_text": ["Hello world!"],
            "attention_mask": [torch.tensor([1, 1, 1])],
        }

        response = client.send_put(0, [0], field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        response = client.send_get(0, [0], ["prompt_text", "attention_mask"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert "prompt_text" in retrieved_data
        assert "attention_mask" in retrieved_data

        assert retrieved_data["prompt_text"][0] == "Hello world!"
        assert len(retrieved_data["attention_mask"]) == 1
        torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1]))
    finally:
        client.close()
def test_clear_data(storage_setup):
    """Test clear operations.

    Stores three samples, clears indexes 0 and 2, and checks that index 1 is
    still readable with its original value.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)
    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        global_indexes = [0, 1, 2]
        field_data = {
            "log_probs": [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
        }

        response = client.send_put(0, global_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # Sanity check: all three samples are present before anything is cleared.
        response = client.send_get(0, [0, 1, 2], ["log_probs"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert len(response.body["data"]["log_probs"]) == 3

        response = client.send_clear(0, [0, 2])
        assert response.request_type == ZMQRequestType.CLEAR_DATA_RESPONSE

        # Only the untouched index is read back; the cleared indexes are not
        # queried by this test.
        response = client.send_get(0, [1], ["log_probs"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert len(response.body["data"]["log_probs"]) == 1
        torch.testing.assert_close(response.body["data"]["log_probs"][0], torch.tensor([2.0]))
    finally:
        client.close()
def test_storage_unit_data_direct():
    """Exercise ``StorageUnitData`` put/get/clear in-process, without ZMQ."""
    from transfer_queue.storage import StorageUnitData

    unit = StorageUnitData(storage_size=10)
    samples = {
        "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
        "rewards": [torch.tensor([10.0]), torch.tensor([20.0])],
    }
    unit.put_data(samples, [0, 1])

    # Both fields must come back, each with one entry per requested index.
    fetched = unit.get_data(["log_probs", "rewards"], [0, 1])
    for field_name in ("log_probs", "rewards"):
        assert field_name in fetched
    for field_name in ("log_probs", "rewards"):
        assert len(fetched[field_name]) == 2

    first_only = unit.get_data(["log_probs"], [0])
    torch.testing.assert_close(first_only["log_probs"][0], torch.tensor([1.0, 2.0]))

    # Clearing index 0 must remove it from the field while leaving index 1.
    unit.clear([0])
    assert 0 not in unit.field_data["log_probs"]
    assert 1 in unit.field_data["log_probs"]
def test_storage_unit_data_capacity_uses_active_keys():
    """Capacity check must use _active_keys, not scan field_data.

    Fills a size-3 unit, expects a fourth key to be rejected, then frees one
    key and expects the same insert to succeed.
    """
    from transfer_queue.storage.simple_storage import StorageUnitData

    unit = StorageUnitData(storage_size=3)

    # Fill the unit to capacity.
    unit.put_data({"f": [1, 2, 3]}, global_indexes=[0, 1, 2])
    assert len(unit._active_keys) == 3

    # A fourth distinct key must be refused while the unit is full.
    with pytest.raises(ValueError, match="Storage capacity exceeded"):
        unit.put_data({"f": [4]}, global_indexes=[3])

    # Freeing one key makes room for the previously rejected insert.
    unit.clear(keys=[2])
    assert len(unit._active_keys) == 2

    unit.put_data({"f": [4]}, global_indexes=[3])
    assert unit._active_keys == {0, 1, 3}
def test_storage_unit_data_parser(storage_setup):
    """Test data_parser functionality in SimpleStorageUnit.

    Writes two columns:
    - normal_data: regular tensors, should remain unchanged
    - data_to_be_parsed: list of shape descriptors (list of ints)

    data_parser converts shape descriptors into random tensors of those shapes.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def create_data_by_shape_parser(field_data):
        """Replace each shape descriptor with a random tensor of that shape."""
        if "data_to_be_parsed" in field_data:
            shapes = field_data["data_to_be_parsed"]
            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
        return field_data

    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        field_data = {
            "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
            "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]],
        }
        global_indexes = [0, 1, 2]

        response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"

        response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        result = response.body["data"]

        # The column the parser does not touch must come back row by row.
        torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0]))
        torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0]))
        torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0]))

        # The parsed values are random, so only their shapes can be checked.
        expected_shapes = [(2, 3), (1, 4), (3, 2)]
        for i, expected_shape in enumerate(expected_shapes):
            actual_shape = tuple(result["data_to_be_parsed"][i].shape)
            assert actual_shape == expected_shape, (
                f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
            )
    finally:
        client.close()
def test_storage_unit_data_parser_callable_types(storage_setup):
    """Test that various callable types (partial, callable class) work as data_parser.

    Sends one PUT with a ``functools.partial`` parser and one with an instance
    of a class defining ``__call__``, then reads the data back and checks the
    parser's transformation was applied to what was stored.
    """
    from functools import partial

    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def _partial_parser(field_data, prefix):
        """Prepend ``prefix`` to every entry of the ``text`` field."""
        if "text" in field_data:
            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
        return field_data

    class CallableParser:
        """Parser object that doubles every entry of the ``value`` field."""

        def __call__(self, field_data):
            if "value" in field_data:
                field_data["value"] = [v * 2 for v in field_data["value"]]
            return field_data

    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        partial_parser = partial(_partial_parser, prefix="parsed_")
        response = client.send_put(
            0,
            [0, 1],
            {"text": ["a", "b"]},
            data_parser=partial_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"

        response = client.send_get(0, [0, 1], ["text"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]

        callable_parser = CallableParser()
        response = client.send_put(
            0,
            [2, 3],
            {"value": [1, 2]},
            data_parser=callable_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, (
            f"callable class parser failed: {response.body}"
        )

        response = client.send_get(0, [2, 3], ["value"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["value"] == [2, 4]
    finally:
        client.close()
def test_storage_unit_data_parser_validation(storage_setup):
    """Test that invalid data_parser inputs produce clear error messages.

    Each case sends a PUT with a misbehaving parser and expects a ``PUT_ERROR``
    reply whose message names the violated rule: a non-callable parser, a
    non-dict return value, a removed key, an added key, and a changed element
    count.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def bad_parser(field_data):
        """Return something that is not a dict."""
        return "not_a_dict"

    def delete_key_parser(field_data):
        """Drop one of the input keys."""
        del field_data["data"]
        return field_data

    def add_key_parser(field_data):
        """Introduce a key that was not in the input."""
        field_data["new_key"] = [999]
        return field_data

    def wrong_len_parser(field_data):
        """Return one element fewer than it was given."""
        field_data["data"] = field_data["data"][:-1]
        return field_data

    # Guarantee cleanup of the ZMQ socket/context even when the test fails.
    try:
        # Case 1: data_parser is not callable at all.
        response = client.send_put(
            0,
            [0],
            {"data": [1]},
            data_parser="not_callable",
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must be callable" in response.body["message"]

        # Case 2: parser returns a non-dict.
        response = client.send_put(
            0,
            [1],
            {"data": [1]},
            data_parser=bad_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must return a dict" in response.body["message"]

        # Case 3: parser removes a key.
        response = client.send_put(
            0,
            [2],
            {"data": [1], "extra": [2]},
            data_parser=delete_key_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must not change dict keys" in response.body["message"]

        # Case 4: parser adds a key.
        response = client.send_put(
            0,
            [3],
            {"data": [1]},
            data_parser=add_key_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must not change dict keys" in response.body["message"]

        # Case 5: parser changes the number of elements in a field.
        response = client.send_put(
            0,
            [4, 5],
            {"data": [1, 2]},
            data_parser=wrong_len_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser changed the number of elements" in response.body["message"]
    finally:
        client.close()