import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import torch
from tensordict import TensorDict
from transfer_queue.metadata import BatchMeta
from transfer_queue.storage.managers.base import KVStorageManager
def get_meta(data, global_indexes=None):
    """Build a ``BatchMeta`` describing every field of ``data``.

    For a nested tensor field the schema records the dtype, the shape of the
    first sample and the list of per-sample shapes. For any other field only
    the first element is inspected: a tensor contributes its dtype and shape,
    anything else contributes its Python type and a ``None`` shape.

    Args:
        data: Batch exposing ``keys()``, ``batch_size`` and item access by
            field name (the tests pass a ``TensorDict``).
        global_indexes: Optional iterable with one global index per sample.
            When it is ``None`` or empty, the indexes default to
            ``0 .. data.batch_size[0] - 1``.

    Returns:
        A ``BatchMeta`` in which every sample belongs to partition ``"0"`` and
        has a production status of 1.
    """
    import numpy as np  # function-local: numpy is not imported at module level

    # Normalise to a list *before* the emptiness check. Testing the raw argument
    # for truthiness raises for multi-element array-likes and cannot detect an
    # empty generator; a plain list has neither problem.
    indexes = [] if global_indexes is None else list(global_indexes)
    if not indexes:
        indexes = list(range(data.batch_size[0]))

    field_schema = {}
    for field_name in data.keys():
        field_data = data[field_name]

        if isinstance(field_data, torch.Tensor) and field_data.is_nested:
            # Samples of a nested tensor may differ in shape, so keep them all.
            per_sample_shapes = [t.shape for t in field_data.unbind()]
            field_schema[field_name] = {
                "dtype": field_data.dtype,
                "shape": per_sample_shapes[0],
                "per_sample_shapes": per_sample_shapes,
                "is_nested": True,
                "is_non_tensor": False,
            }
            continue

        # Regular fields are described by their first element only.
        first = field_data[0]
        is_tensor = isinstance(first, torch.Tensor)
        field_schema[field_name] = {
            "dtype": first.dtype if is_tensor else type(first),
            "shape": first.shape if is_tensor else None,
            "is_nested": False,
            "is_non_tensor": not is_tensor,
        }

    return BatchMeta(
        global_indexes=indexes,
        partition_ids=["0"] * len(indexes),
        field_schema=field_schema,
        production_status=np.ones(len(indexes), dtype=np.int8),
    )
@pytest.fixture
def test_data():
    """Provide the shared manager config, a jagged TensorDict batch and its metadata."""

    def jagged(*rows):
        # Pack variable-length integer rows into a single jagged nested tensor.
        return torch.nested.as_nested_tensor(
            [torch.tensor(row) for row in rows],
            layout=torch.jagged,
        )

    sample_indexes = [8, 9, 10]
    batch = TensorDict(
        {
            "input_ids": jagged([1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11]),
            "prompt_ids": jagged([1, 2], [6, 7, 8], [10]),
            "response_ids": jagged([3, 4, 5], [9], [11]),
            "response_mask": jagged([0, 0, 1, 1, 1], [0, 0, 0, 1], [0, 1]),
        },
        batch_size=3,
    )

    return {
        "cfg": {
            "controller_info": MagicMock(),
            "client_name": "YuanrongStorageClient",
            "worker_port": 31501,
            "device_id": 0,
        },
        "field_names": batch.keys(),
        "global_indexes": sample_indexes,
        "data": batch,
        "metadata": get_meta(batch, sample_indexes),
    }
def test_generate_keys(test_data):
    """Check that _generate_keys emits ``<global_index>@<field>`` keys, grouped field by field."""
    generated = KVStorageManager._generate_keys(test_data["data"].keys(), test_data["metadata"].global_indexes)

    # Field is the outer loop, global index the inner one.
    wanted = [
        f"{index}@{field}"
        for field in ("input_ids", "prompt_ids", "response_ids", "response_mask")
        for index in (8, 9, 10)
    ]

    assert generated == wanted
    assert len(generated) == 12
def test_generate_values(test_data):
    """
    Check that _generate_values flattens the TensorDict into an ordered list of tensors,
    with field_name as the primary key and global_index as the secondary key.
    """
    flattened = KVStorageManager._generate_values(test_data["data"])

    wanted = [
        torch.tensor(row)
        for row in (
            # input_ids
            [1, 2, 3, 4, 5],
            [6, 7, 8, 9],
            [10, 11],
            # prompt_ids
            [1, 2],
            [6, 7, 8],
            [10],
            # response_ids
            [3, 4, 5],
            [9],
            [11],
            # response_mask
            [0, 0, 1, 1, 1],
            [0, 0, 0, 1],
            [0, 1],
        )
    ]

    assert len(flattened) == len(test_data["field_names"]) * len(test_data["global_indexes"])
    for actual, reference in zip(flattened, wanted, strict=True):
        assert torch.equal(actual, reference)
@patch("transfer_queue.storage.managers.base.StorageClientFactory.create")
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
def test_merge_tensors_to_tensordict(mock_create, test_data):
    """Check that _merge_tensors_to_tensordict rebuilds the batch flattened by _generate_values."""
    fake_client = MagicMock()
    mock_create.return_value = fake_client

    manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"])
    assert manager.storage_client is fake_client
    assert manager._multi_threads_executor is None

    def assert_same_rows(actual, reference, field):
        # Compare two batched tensors sample by sample along dim 0.
        rows_a = actual.unbind(0)
        rows_b = reference.unbind(0)
        assert len(rows_a) == len(rows_b), f"Length mismatch for {field}: {len(rows_a)} vs {len(rows_b)}"
        for left, right in zip(rows_a, rows_b, strict=True):
            assert torch.equal(left, right)

    # Round trip of the all-jagged fixture batch.
    source = test_data["data"]
    rebuilt = manager._merge_tensors_to_tensordict(test_data["metadata"], manager._generate_values(source))

    field_names = ["input_ids", "prompt_ids", "response_ids", "response_mask"]
    for field in field_names:
        assert field in rebuilt
    for field in field_names:
        assert_same_rows(rebuilt[field], source[field], field)
    assert rebuilt.batch_size == torch.Size([3])

    # Round trip of a batch mixing a nested tensor, strings and a heterogeneous list.
    mixed = TensorDict(
        {
            "input_ids": torch.nested.nested_tensor([[1, 2], [3], [4]]),
            "prompt": ["5", "6", "7"],
            "extra": [torch.Tensor([8]), "9", torch.Tensor([10])],
        },
        batch_size=[3],
    )
    mixed_rebuilt = manager._merge_tensors_to_tensordict(get_meta(mixed), manager._generate_values(mixed))

    assert "input_ids" in mixed_rebuilt
    assert "prompt" in mixed_rebuilt
    for field in mixed_rebuilt.keys():
        merged_value = mixed_rebuilt[field]
        if isinstance(merged_value, torch.Tensor):
            assert_same_rows(merged_value, mixed[field], field)
        else:
            assert merged_value == mixed[field]
def test_get_shape_type_custom_backend_meta_list_without_custom_backend_meta(test_data):
    """Check the shapes and dtypes from _get_shape_type_custom_backend_meta_list when no custom meta is set."""
    result = KVStorageManager._get_shape_type_custom_backend_meta_list(test_data["metadata"])
    shapes, dtypes, custom_backend_meta_list = result

    # One entry per (field, sample) pair.
    entry_count = len(test_data["field_names"]) * len(test_data["global_indexes"])

    # Sample lengths listed field by field: input_ids, prompt_ids, response_ids, response_mask.
    sample_lengths = (5, 4, 2, 2, 3, 1, 3, 1, 1, 5, 4, 2)

    assert shapes == [torch.Size([length]) for length in sample_lengths]
    assert dtypes == [torch.int64] * entry_count
    assert custom_backend_meta_list == [None] * entry_count
def test_get_shape_type_custom_backend_meta_list_with_custom_backend_meta(test_data):
    """Check that per-sample custom_backend_meta comes back flattened field by field."""
    field_order = ("input_ids", "prompt_ids", "response_ids", "response_mask")

    def tag(number):
        # Distinct payload per (sample, field) pair so ordering mistakes are visible.
        return {f"key{number}": f"value{number}"}

    metadata = test_data["metadata"]
    # Sample-major input: sample s holds tags 4*s + 1 .. 4*s + 4, one per field.
    metadata._custom_backend_meta = [
        {
            field: tag(sample * len(field_order) + column)
            for column, field in enumerate(field_order, start=1)
        }
        for sample in range(3)
    ]

    _, _, flattened_meta = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)

    # Field-major output: all samples of input_ids, then prompt_ids, and so on.
    wanted_order = (1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12)
    assert flattened_meta == [tag(number) for number in wanted_order]
def test_get_shape_type_custom_backend_meta_list_with_partial_custom_backend_meta(test_data):
    """Check that (sample, field) pairs without custom_backend_meta are reported as None."""
    first_sample = {"input_ids": {"key1": "value1"}}
    second_sample = {}
    third_sample = {
        "prompt_ids": {"key2": "value2"},
        "response_ids": {"key3": "value3"},
    }

    metadata = test_data["metadata"]
    metadata._custom_backend_meta = [first_sample, second_sample, third_sample]

    _, _, flattened_meta = KVStorageManager._get_shape_type_custom_backend_meta_list(metadata)

    # 4 fields x 3 samples, field-major; only three slots carry a payload.
    wanted = [None] * 12
    wanted[0] = {"key1": "value1"}  # input_ids, first sample
    wanted[5] = {"key2": "value2"}  # prompt_ids, third sample
    wanted[8] = {"key3": "value3"}  # response_ids, third sample
    assert flattened_meta == wanted
@pytest.fixture
def test_data_for_put_data():
    """Provide a small dense two-field batch and its metadata for the put_data tests."""
    sample_indexes = [0, 1, 2]
    batch = TensorDict(
        {
            # int64 tensor [[1, 2], [3, 4], [5, 6]]
            "text": torch.arange(1, 7).reshape(3, 2),
            # int64 tensor [0, 1, 2]
            "label": torch.arange(3),
        },
        batch_size=3,
    )
    return dict(
        field_names=["text", "label"],
        global_indexes=sample_indexes,
        data=batch,
        metadata=get_meta(batch, sample_indexes),
    )
# Dotted path of the storage-client factory; the put_data tests patch its ``create`` attribute.
STORAGE_CLIENT_FACTORY_PATH = "transfer_queue.storage.managers.base.StorageClientFactory"
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
def test_put_data_with_custom_backend_meta_from_storage_client(mock_notify, test_data_for_put_data):
    """Check that put_data regroups the storage client's custom_backend_meta per sample and field."""
    # Keys are expected field by field, "label" before "text".
    field_order = ("label", "text")
    sample_order = (0, 1, 2)
    storage_keys = [f"{index}@{field}" for field in field_order for index in sample_order]

    # The fake client answers ``put`` with one meta dict per key, in key order.
    fake_client = MagicMock()
    fake_client.put.return_value = [{"storage_key": key} for key in storage_keys]

    with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=fake_client):
        manager = KVStorageManager(controller_info=MagicMock(), config={"client_name": "MockClient"})
        asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))

    # The storage client receives the flattened keys and one value per key.
    fake_client.put.assert_called_once()
    put_args = fake_client.put.call_args.args
    assert put_args[0] == storage_keys
    assert len(put_args[1]) == 6

    # The fourth positional argument of notify_data_update carries the regrouped meta.
    mock_notify.assert_called_once()
    notified_meta = mock_notify.call_args.args[3]
    for index in sample_order:
        assert index in notified_meta
    for index in sample_order:
        for field in field_order:
            assert notified_meta[index][field] == {"storage_key": f"{index}@{field}"}

    # The metadata object passed to put_data is updated as well.
    stored_meta = test_data_for_put_data["metadata"]._custom_backend_meta
    assert len(stored_meta) == 3
    assert stored_meta[0]["label"] == {"storage_key": "0@label"}
    assert stored_meta[2]["text"] == {"storage_key": "2@text"}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
def test_put_data_without_custom_backend_meta(mock_notify, test_data_for_put_data):
    """Check that put_data notifies with an empty meta dict when the client's ``put`` returns None."""
    fake_client = MagicMock(**{"put.return_value": None})
    manager_config = {"controller_info": MagicMock(), "client_name": "MockClient"}

    with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=fake_client):
        manager = KVStorageManager(controller_info=MagicMock(), config=manager_config)
        asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))

    mock_notify.assert_called_once()
    # The fourth positional argument of notify_data_update is the per-field custom meta.
    assert mock_notify.call_args.args[3] == {}
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
def test_put_data_custom_backend_meta_length_mismatch_raises_error(test_data_for_put_data):
    """Check that put_data raises ValueError when the client returns too few custom_backend_meta entries."""
    # Three meta entries for a batch of two fields and three samples.
    fake_client = MagicMock()
    fake_client.put.return_value = [{"key": str(number)} for number in (1, 2, 3)]
    manager_config = {"controller_info": MagicMock(), "client_name": "MockClient"}

    with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=fake_client):
        manager = KVStorageManager(controller_info=MagicMock(), config=manager_config)
        with pytest.raises(ValueError) as raised:
            asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))

    # Checked after the ``raises`` block exits; inside it the line would never run.
    assert "does not match" in str(raised.value)