import os
import sys
import time
import numpy as np
import pytest
import ray
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData
# Ordered names of every field that generate_complex_data() emits; that function
# asserts its own keys match this list. Kept as a list because tests extend it
# with `+ [...]` to request extra fields.
DEFAULT_FIELDS = [
    # Dense tensors, one per dtype
    "tensor_f32", "tensor_i64", "tensor_bf16", "tensor_f16",
    # Nested tensors, one per layout
    "nested_jagged", "nested_strided",
    # Plain Python lists
    "list_int", "list_str", "list_obj",
    # NumPy arrays
    "np_array", "np_bytes_str", "np_obj",
    # Inf/NaN carrier and non-tensor payload
    "special_val", "non_tensor_stack",
]
# Per-backend configuration trees handed to transfer_queue.init() (via OmegaConf).
# Every backend shares the same skeleton: a polling controller plus a "backend"
# section that names the storage backend and carries its options under that name.
BACKEND_CONFIGS = {
    storage_name: {
        "controller": {
            "polling_mode": True,
        },
        "backend": {
            "storage_backend": storage_name,
            storage_name: storage_options,
        },
    }
    for storage_name, storage_options in (
        (
            "SimpleStorage",
            {
                "total_storage_size": 200,
                "num_data_storage_units": 2,
            },
        ),
        (
            "MooncakeStore",
            {
                "global_segment_size": 134217728,
                "local_buffer_size": 134217728,
                "metadata_server": "localhost:50050",
                "master_server_address": "localhost:50051",
                "protocol": "tcp",
                "device_name": "",
            },
        ),
        (
            "Yuanrong",
            {
                "worker_port": 31501,
                "metastore_port": 2379,
            },
        ),
    )
}
@pytest.fixture(scope="module")
def ray_cluster():
    """Ensure a Ray runtime is available for the duration of the test module.

    If Ray is not yet initialized, this fixture starts it and shuts it down again
    at teardown. If Ray was already running when the fixture was entered, it is
    left running at teardown: the fixture only tears down what it started.
    """
    started_here = not ray.is_initialized()
    if started_here:
        ray.init(ignore_reinit_error=True)
    try:
        yield
    finally:
        # Only shut down a runtime this fixture owns.
        if started_here and ray.is_initialized():
            ray.shutdown()
@pytest.fixture(scope="module")
def backend_name():
    """Return the storage backend under test, read from the environment.

    The name is taken from ``TQ_TEST_BACKEND`` (SimpleStorage, MooncakeStore, or
    Yuanrong) and falls back to ``SimpleStorage`` when the variable is unset.

    Examples:
        TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py
        TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py
        TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_e2e_lifecycle_consistency.py
    """
    selected_backend = os.getenv("TQ_TEST_BACKEND", "SimpleStorage")
    return selected_backend
@pytest.fixture(scope="module")
def e2e_client(ray_cluster, backend_name):
    """Create a client using transfer_queue.init() for lifecycle testing.

    Args:
        ray_cluster: Ray cluster fixture
        backend_name: Backend name from TQ_TEST_BACKEND env var

    Yields:
        The client returned by ``transfer_queue.get_client()``.

    Raises:
        ValueError: If ``backend_name`` has no entry in ``BACKEND_CONFIGS``.
    """
    import transfer_queue

    if backend_name not in BACKEND_CONFIGS:
        raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}")
    config = BACKEND_CONFIGS[backend_name]
    transfer_queue.init(OmegaConf.create(config))
    try:
        # get_client() sits inside the try so that a failure here still closes
        # whatever init() set up.
        client = transfer_queue.get_client()
        yield client
    finally:
        transfer_queue.close()
def generate_complex_data(indices: list[int]) -> TensorDict:
    """Build a TensorDict holding one sample per index for every field in DEFAULT_FIELDS.

    All values are derived deterministically from ``indices``, so calling this
    again with the same indices yields the expected data for comparisons.

    Args:
        indices: Per-sample seed values; the batch size is ``len(indices)``.

    Returns:
        A TensorDict with ``batch_size=len(indices)`` whose keys follow DEFAULT_FIELDS order.
    """
    n = len(indices)

    def stacked_ranges(dtype):
        # One row [i, i+1, ..., i+4] per sample, in the requested dtype.
        return torch.stack([torch.arange(start, start + 5, dtype=dtype) for start in indices])

    # Jagged rows: the row length varies with the index (2..5 elements).
    jagged_rows = [torch.arange(start, start + 2 + (start % 4), dtype=torch.float32) for start in indices]
    # Plant Inf in the first row and NaN in the second (when present).
    jagged_rows[0][0] = float("inf")
    if len(jagged_rows) > 1:
        jagged_rows[1][0] = float("nan")

    # Columns: +Inf, NaN, and the sample index itself.
    special_val = torch.zeros(n, 3)
    special_val[:, 0] = float("inf")
    special_val[:, 1] = float("nan")
    special_val[:, 2] = torch.tensor(indices, dtype=torch.float32)

    field_values = {
        "tensor_f32": stacked_ranges(torch.float32),
        "tensor_i64": stacked_ranges(torch.int64),
        "tensor_bf16": stacked_ranges(torch.bfloat16),
        "tensor_f16": stacked_ranges(torch.float16),
        "nested_jagged": torch.nested.as_nested_tensor(jagged_rows, layout=torch.jagged),
        "nested_strided": torch.nested.nested_tensor(
            [torch.full((3, 4), float(i)) for i in indices],
            layout=torch.strided,
        ),
        "list_int": [i * 10 for i in indices],
        "list_str": [f"sample_{i}" for i in indices],
        "list_obj": [{"key": f"value_{i}", "num": i} for i in indices],
        "np_array": np.array([np.arange(i, i + 3) for i in indices], dtype=np.float64),
        "np_bytes_str": np.array([f"bs_{i}".encode() for i in indices], dtype="|S10"),
        "np_obj": np.array([f"obj_{i}" for i in indices], dtype=object),
        "special_val": special_val,
        "non_tensor_stack": NonTensorData(
            data=[{"idx": i, "text": f"non_tensor_{i}"} for i in indices],
            batch_size=(n,),
            device=None,
        ),
    }

    # Guard against this function drifting out of sync with DEFAULT_FIELDS.
    produced = set(field_values)
    declared = set(DEFAULT_FIELDS)
    assert produced == declared, (
        f"generate_complex_data fields mismatch with DEFAULT_FIELDS: "
        f"extra={produced - declared}, "
        f"missing={declared - produced}"
    )

    ordered_fields = {name: field_values[name] for name in DEFAULT_FIELDS}
    return TensorDict(ordered_fields, batch_size=n)
def poll_for_meta(
    client,
    partition_id,
    data_fields,
    batch_size,
    task_name,
    mode="fetch",
    max_retries=10,
    retry_interval=0.3,
):
    """Poll ``client.get_meta`` until it returns non-empty metadata or retries run out.

    Args:
        client: Object exposing ``get_meta(partition_id=, data_fields=, batch_size=, mode=, task_name=)``.
        partition_id: Partition to query.
        data_fields: Field names to request.
        batch_size: Number of samples to request.
        task_name: Task name forwarded to ``get_meta``.
        mode: Mode forwarded to ``get_meta``.
        max_retries: Maximum number of ``get_meta`` attempts.
        retry_interval: Seconds to wait between consecutive attempts.

    Returns:
        The first metadata object that is not None and has ``size > 0``,
        or None if every attempt came back empty.
    """
    for attempt in range(max_retries):
        meta = client.get_meta(
            partition_id=partition_id,
            data_fields=data_fields,
            batch_size=batch_size,
            mode=mode,
            task_name=task_name,
        )
        if meta is not None and meta.size > 0:
            return meta
        # No point sleeping after the last attempt; nothing follows it.
        if attempt < max_retries - 1:
            time.sleep(retry_interval)
    return None
def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bool:
"""Verify special values (NaN, Inf) are preserved."""
if len(retrieved) != len(expected):
return False
for r, e in zip(retrieved, expected, strict=True):
if not (torch.isinf(r[0]) and r[0] > 0):
return False
if not torch.isnan(r[1]):
return False
if not torch.allclose(r[2], e[2]):
return False
return True
def verify_nested_tensor_equal(retrieved, expected) -> bool:
"""Verify nested tensors element by element, handling NaN/Inf."""
r_list = retrieved.unbind()
e_list = expected.unbind()
if len(r_list) != len(e_list):
return False
for r, e in zip(r_list, e_list, strict=True):
r_nan = torch.isnan(r)
e_nan = torch.isnan(e)
if not torch.equal(r_nan, e_nan):
return False
mask = ~r_nan
if mask.any() and not torch.allclose(r[mask], e[mask]):
return False
return True
def verify_non_tensor_data(retrieved, expected) -> bool:
"""Verify NonTensorData content element by element."""
if hasattr(retrieved, "tolist"):
retrieved = retrieved.tolist()
elif hasattr(retrieved, "data"):
retrieved = retrieved.data
if hasattr(expected, "tolist"):
expected = expected.tolist()
elif hasattr(expected, "data"):
expected = expected.data
if isinstance(retrieved, list) and isinstance(expected, list):
if len(retrieved) != len(expected):
return False
return all(r == e for r, e in zip(retrieved, expected, strict=True))
return retrieved == expected
def verify_list_equal(retrieved, expected) -> bool:
    """Compare list-like field content after converting both sides to native Python values.

    Note: TensorDict may materialize Python lists as Tensors or NonTensorStack during
    storage/retrieval, so both sides are normalized before the comparison. A regular
    (non-nested) retrieved tensor is flattened first; a regular expected tensor is not.
    """
    from tensordict.tensorclass import NonTensorStack

    def to_native(value, flatten):
        if isinstance(value, NonTensorStack):
            return value.tolist()
        if not isinstance(value, torch.Tensor):
            return value
        if value.is_nested:
            return [component.item() for component in value]
        return value.reshape(-1).tolist() if flatten else value.tolist()

    got = to_native(retrieved, flatten=True)
    want = to_native(expected, flatten=False)
    return got == want
def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict:
    """Return a new TensorDict whose samples are arranged according to ``order``.

    Position ``k`` of the result holds sample ``order[k]`` of ``td``. NonTensorStack,
    nested tensors and plain lists are rebuilt item by item; every other field is
    indexed with an index tensor built from ``order``.
    """
    from tensordict.tensorclass import NonTensorStack

    def permute(field):
        if isinstance(field, NonTensorStack):
            entries = field.tolist()
            return NonTensorStack(*[entries[pos] for pos in order], batch_size=[len(order)])
        if isinstance(field, torch.Tensor) and field.is_nested:
            components = list(field)
            # Rebuild with the same layout the field arrived in.
            return torch.nested.as_nested_tensor([components[pos] for pos in order], layout=field.layout)
        if isinstance(field, list):
            return [field[pos] for pos in order]
        return field[torch.tensor(order)]

    permuted = {key: permute(td[key]) for key in td.keys()}
    return TensorDict(permuted, batch_size=td.batch_size)
def recover_local_index(global_index_order, new_global_index_order):
    """Map each value of ``global_index_order`` to its position in ``new_global_index_order``.

    The result lists, for every value in ``global_index_order`` (in order), the index
    at which that value appears in ``new_global_index_order``. If a value appears more
    than once there, its last position is used. Raises KeyError for a value that is
    absent from ``new_global_index_order``.
    """
    position_of = {value: position for position, value in enumerate(new_global_index_order)}
    return [position_of[value] for value in global_index_order]
def test_core_consistency(e2e_client):
    """Round-trip the full complex dataset through put/get and check every field type."""
    partition_id = "test_core_consistency"
    batch_size = 20
    task_name = "core_consistency_task"
    original_data = generate_complex_data(list(range(batch_size)))

    put_meta = e2e_client.put(data=original_data, partition_id=partition_id)
    assert put_meta.size == batch_size, f"Expected batch_size {batch_size}, got {put_meta.size}"

    def retrieved_as_list(value):
        # Nested tensors are unpacked item by item; anything else with tolist() uses it.
        if isinstance(value, torch.Tensor) and value.is_nested:
            return [component.item() for component in value]
        if hasattr(value, "tolist"):
            return value.tolist()
        return value

    def expected_as_list(value):
        # NumPy arrays are left as-is and compared through list() below.
        if hasattr(value, "tolist") and not isinstance(value, np.ndarray):
            return value.tolist()
        return value

    try:
        retrieved_meta = poll_for_meta(
            e2e_client, partition_id, DEFAULT_FIELDS, batch_size, task_name, mode="fetch"
        )
        assert retrieved_meta is not None and retrieved_meta.size == batch_size, "Failed to retrieve metadata"
        retrieved_data = e2e_client.get_data(retrieved_meta)

        # Dense tensors: float32 compared approximately, the others exactly.
        dense_checks = (
            ("tensor_f32", torch.allclose),
            ("tensor_i64", torch.equal),
            ("tensor_bf16", torch.equal),
            ("tensor_f16", torch.equal),
        )
        for row in range(batch_size):
            for name, same in dense_checks:
                assert same(retrieved_data[name][row], original_data[name][row]), (
                    f"{name} mismatch at index {row}"
                )

        # Nested tensors (NaN/Inf aware comparison).
        for name, failure in (
            ("nested_jagged", "Jagged nested tensor mismatch"),
            ("nested_strided", "Strided nested tensor mismatch"),
        ):
            assert verify_nested_tensor_equal(retrieved_data[name], original_data[name]), failure

        # Python lists.
        for name in ("list_int", "list_str", "list_obj"):
            assert verify_list_equal(retrieved_data[name], original_data[name]), f"{name} mismatch"

        # Numeric NumPy array, row by row.
        for row in range(batch_size):
            assert np.allclose(retrieved_data["np_array"][row].numpy(), original_data["np_array"][row]), (
                f"np_array mismatch at index {row}"
            )

        # Bytes-string and object NumPy arrays, compared as plain lists.
        for name in ("np_bytes_str", "np_obj"):
            got = retrieved_as_list(retrieved_data[name])
            want = expected_as_list(original_data[name])
            assert list(got) == list(want), f"{name} mismatch"

        assert verify_special_values(retrieved_data["special_val"], original_data["special_val"]), (
            "Special values (NaN/Inf) not preserved"
        )
        assert verify_non_tensor_data(retrieved_data["non_tensor_stack"], original_data["non_tensor_stack"]), (
            "NonTensorData content mismatch"
        )
    finally:
        e2e_client.clear_partition(partition_id)
def test_cross_shard_complex_update(e2e_client):
    """Cross-shard update: put A + put B, update overlapping region, verify all regions.

    Allocates 40 samples, writes samples 0-19 (A) and 20-39 (B), then overwrites
    samples 10-29 with data generated from shifted indices plus two brand-new
    fields. Afterwards it checks that ``tensor_f32`` holds the original values in
    regions 0-9 and 30-39 and the updated values in region 10-29, and that the
    new fields can be fetched for the updated samples.
    """
    client = e2e_client
    partition_id = "test_cross_shard_update"
    task_name = "cross_shard_task"
    idx_a = list(range(0, 20))
    idx_b = list(range(20, 40))
    idx_update = list(range(10, 30))
    base_fields = DEFAULT_FIELDS
    alloc_meta = client.get_meta(
        partition_id=partition_id,
        data_fields=base_fields,
        batch_size=40,
        mode="insert",
        task_name="allocator",
    )
    assert len(alloc_meta.global_indexes) == 40, "Failed to allocate 40 samples"
    try:
        # Initial writes: A -> samples 0-19, B -> samples 20-39.
        data_a = generate_complex_data(idx_a)
        meta_a = alloc_meta.select_samples(list(range(0, 20)))
        client.put(data=data_a, metadata=meta_a)
        data_b = generate_complex_data(idx_b)
        meta_b = alloc_meta.select_samples(list(range(20, 40)))
        client.put(data=data_b, metadata=meta_b)

        # Overlapping update of samples 10-29: shifted values plus two new fields.
        modified_indices = [i + 1000 for i in idx_update]
        data_update = generate_complex_data(modified_indices)
        new_extra_tensor = torch.stack([torch.ones(3) * i for i in idx_update])
        new_extra_non_tensor = NonTensorData(
            data=[{"new_field": f"new_{i}"} for i in idx_update],
            batch_size=(len(idx_update),),
            device=None,
        )
        data_update["new_extra_tensor"] = new_extra_tensor
        data_update["new_extra_non_tensor"] = new_extra_non_tensor
        meta_update = alloc_meta.select_samples(list(range(10, 30)))
        client.put(data=data_update, metadata=meta_update)

        full_meta = poll_for_meta(client, partition_id, base_fields, 40, task_name, mode="force_fetch")
        assert full_meta is not None and full_meta.size == 40, "Failed to retrieve full metadata"
        full_data = client.get_data(full_meta)

        # Sort by global index so position i lines up with the i-th allocated sample.
        sorted_order = sorted(range(full_meta.size), key=lambda i: full_meta.global_indexes[i])
        if sorted_order != list(range(full_meta.size)):
            full_data = _reorder_tensordict(full_data, sorted_order)

        original_data_0_9 = generate_complex_data(list(range(0, 10)))
        for i in range(10):
            assert torch.allclose(full_data["tensor_f32"][i], original_data_0_9["tensor_f32"][i]), (
                f"Region 0-9 tensor_f32 mismatch at index {i}"
            )
        updated_expected = generate_complex_data([i + 1000 for i in range(10, 30)])
        for i in range(20):
            assert torch.allclose(full_data["tensor_f32"][10 + i], updated_expected["tensor_f32"][i]), (
                f"Region 10-29 tensor_f32 mismatch at index {10 + i}"
            )
        original_data_30_39 = generate_complex_data(list(range(30, 40)))
        for i in range(10):
            assert torch.allclose(full_data["tensor_f32"][30 + i], original_data_30_39["tensor_f32"][i]), (
                f"Region 30-39 tensor_f32 mismatch at index {30 + i}"
            )

        extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
        extended_meta = poll_for_meta(
            client,
            partition_id,
            extended_fields,
            40,
            task_name,
            mode="force_fetch",
        )
        assert extended_meta is not None and extended_meta.size > 0, (
            "Failed to fetch extended metadata for update region; poll_for_meta returned no or empty metadata."
        )

        # Locate the updated samples inside extended_meta itself. It is fetched
        # separately from full_meta, so its sample order and size need not match;
        # positions taken from full_meta could select the wrong samples or be out of range.
        update_gis = set(meta_update.global_indexes)
        update_positions = [
            i for i, global_index in enumerate(extended_meta.global_indexes) if global_index in update_gis
        ]
        assert update_positions, "Extended metadata contains none of the updated samples"

        extended_meta = extended_meta.select_samples(update_positions).select_fields(extended_fields)
        update_region_data = client.get_data(extended_meta)
        assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist"
        assert "new_extra_non_tensor" in update_region_data.keys(), "new_extra_non_tensor should exist"
    finally:
        client.clear_partition(partition_id)
def test_production_status_lifecycle(e2e_client):
    """Multi-round partial put: verify production & consumption status transitions.

    Round 1 writes only the Set_A fields, Round 2 writes the Set_B fields for the
    same allocated samples. Production status is checked after each round, then
    consumption status is checked before and after fetching the metadata.
    """
    client = e2e_client
    partition_id = "test_production_lifecycle"
    batch_size = 10
    task_name = "production_lifecycle_task"
    set_a_fields = ["tensor_f32", "tensor_i64", "list_int", "list_str"]
    set_b_fields = ["nested_jagged", "np_array", "special_val"]
    all_fields = set_a_fields + set_b_fields
    # Allocate the samples up front so both rounds write into the same rows.
    alloc_meta = client.get_meta(
        partition_id=partition_id,
        data_fields=all_fields,
        batch_size=batch_size,
        mode="insert",
        task_name="allocator",
    )
    indices = alloc_meta.global_indexes
    assert len(indices) == batch_size, f"Expected {batch_size} indices, got {len(indices)}"
    try:
        full_data = generate_complex_data(indices)

        # Round 1: write Set_A only.
        set_a_data = full_data.select(*set_a_fields)
        client.put(data=set_a_data, metadata=alloc_meta)
        set_a_ready = client.check_production_status(data_fields=set_a_fields, partition_id=partition_id)
        assert set_a_ready, "Set_A fields should be ready after Round 1"
        set_b_ready = client.check_production_status(data_fields=set_b_fields, partition_id=partition_id)
        assert not set_b_ready, "Set_B fields should NOT be ready after Round 1"
        all_ready_before = client.check_production_status(data_fields=all_fields, partition_id=partition_id)
        assert not all_ready_before, "All fields should NOT be ready before Round 2"

        # Round 2: write Set_B, completing every requested field.
        set_b_data = full_data.select(*set_b_fields)
        client.put(data=set_b_data, metadata=alloc_meta)
        all_ready_after = client.check_production_status(data_fields=all_fields, partition_id=partition_id)
        assert all_ready_after, "All fields should be ready after Round 2"

        # Consumption: nothing consumed yet for this task.
        is_consumed = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert not is_consumed, "Data should not be consumed initially"
        meta = poll_for_meta(client, partition_id, all_fields, batch_size, task_name, mode="fetch")
        assert meta is not None, "Failed to poll metadata"
        # The test expects the samples to count as consumed as soon as get_meta
        # succeeds in "fetch" mode, i.e. before get_data is called.
        is_consumed_mid = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert is_consumed_mid, "Data should already be consumed after get_meta in fetch mode"
        client.get_data(meta)
        # ...and to remain consumed after the data itself has been read.
        is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert is_consumed_after, "Data should be consumed after get_meta in fetch mode"
    finally:
        client.clear_partition(partition_id)
def test_custom_metadata_persistence(e2e_client):
    """Attach per-sample custom metadata, fetch the metadata again, and compare entries."""
    partition_id = "test_custom_meta"
    batch_size = 8
    put_meta = e2e_client.put(
        data=generate_complex_data(list(range(batch_size))),
        partition_id=partition_id,
    )
    assert put_meta.size == batch_size, f"Expected batch_size {batch_size}, got {put_meta.size}"
    try:
        # One entry per sample: a float score, a string label and a list of tags.
        expected_entries = []
        for sample in range(batch_size):
            expected_entries.append(
                {
                    "score": float(sample) / 10.0,
                    "label": f"label_{sample}",
                    "tags": [f"tag_{sample}_a", f"tag_{sample}_b"],
                }
            )
        put_meta.update_custom_meta(expected_entries)
        e2e_client.set_custom_meta(put_meta)

        fetched_meta = poll_for_meta(
            e2e_client, partition_id, DEFAULT_FIELDS, batch_size, "custom_meta_task", mode="force_fetch"
        )
        assert fetched_meta is not None, "Failed to retrieve metadata"
        fetched_entries = fetched_meta.get_all_custom_meta()
        assert len(fetched_entries) == batch_size, (
            f"Expected {batch_size} custom_meta entries, got {len(fetched_entries)}"
        )

        checked_keys = (("score", "Score"), ("label", "Label"), ("tags", "Tags"))
        for sample, (actual, expected) in enumerate(zip(fetched_entries, expected_entries, strict=True)):
            for key, title in checked_keys:
                assert actual[key] == expected[key], f"{title} mismatch at sample {sample}"
    finally:
        e2e_client.clear_partition(partition_id)
def test_reset_consumption(e2e_client):
    """Consume data, reset consumption status, verify re-consumability.

    Sequence: put -> fetch metadata (expected to mark the samples consumed for
    ``task_name``) -> get data -> reset_consumption -> fetch metadata again.
    """
    client = e2e_client
    partition_id = "test_reset_consumption"
    batch_size = 10
    task_name = "reset_test_task"
    fields = DEFAULT_FIELDS
    data = generate_complex_data(list(range(batch_size)))
    client.put(data=data, partition_id=partition_id)
    try:
        # Before any fetch, the task has consumed nothing.
        is_consumed_initial = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert not is_consumed_initial, "Data should not be consumed initially"
        meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch")
        assert meta is not None and meta.size == batch_size, "Failed to poll metadata"
        # The test expects "fetch" mode to flag consumption at get_meta time.
        is_consumed_mid = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert is_consumed_mid, "Data should already be consumed after get_meta in fetch mode"
        retrieved_data = client.get_data(meta)
        assert retrieved_data.batch_size[0] == batch_size, "Retrieved data batch_size mismatch"
        is_consumed_after = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert is_consumed_after, "Data should be consumed after get_meta in fetch mode"

        # Reset, then confirm the status flips back and the same task can fetch again.
        success = client.reset_consumption(partition_id=partition_id, task_name=task_name)
        assert success, "reset_consumption should return True"
        is_consumed_reset = client.check_consumption_status(task_name=task_name, partition_id=partition_id)
        assert not is_consumed_reset, "Consumption status should be False after reset"
        meta_again = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="fetch")
        assert meta_again is not None and meta_again.size == batch_size, "Should be able to fetch metadata again"
    finally:
        client.clear_partition(partition_id)
def test_clear_partition(e2e_client):
    """Clear partition: verify data removal and production status reset.

    Checks that the partition is listed and ready before ``clear_partition``,
    and that afterwards it is no longer listed and no longer reports ready.
    """
    client = e2e_client
    partition_id = "test_clear_partition"
    batch_size = 15
    task_name = "clear_test_task"
    fields = DEFAULT_FIELDS
    data = generate_complex_data(list(range(batch_size)))
    client.put(data=data, partition_id=partition_id)
    try:
        # State before the clear: produced, fetchable, and listed.
        is_ready = client.check_production_status(data_fields=fields, partition_id=partition_id)
        assert is_ready, "Data should be ready after put"
        meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
        assert meta is not None and meta.size == batch_size, "Failed to poll metadata"
        partition_list_before = client.get_partition_list()
        assert partition_id in partition_list_before, "Partition should exist before clear"

        # The operation under test.
        client.clear_partition(partition_id)

        # State after the clear: not listed, not ready.
        partition_list_after = client.get_partition_list()
        assert partition_id not in partition_list_after, "Partition should be removed after clear"
        is_ready_after_clear = client.check_production_status(data_fields=fields, partition_id=partition_id)
        assert not is_ready_after_clear, "Production status should be False after clear"
    finally:
        # Best-effort cleanup: on the success path the partition has already been
        # cleared above, so any error from clearing it a second time is ignored.
        try:
            client.clear_partition(partition_id)
        except Exception:
            pass
def test_dynamic_tensor_shape_nested_transition(e2e_client):
    """
    Test transition from regular tensor to nested tensor.

    First put tensors of identical shape, then put tensors of a different shape.
    Verify that the field schema marks is_nested=True, and getting all samples returns a nested tensor.
    """
    client = e2e_client
    partition_id = "test_nested_transition_partition"
    task_name = "test_task"
    field = "dynamic_feature"
    try:
        # Phase 1: two samples of identical shape (4,). The schema is expected to be non-nested.
        data1 = TensorDict({field: torch.ones(2, 4)}, batch_size=2)
        meta1_put = client.put(data=data1, partition_id=partition_id)
        assert meta1_put.size == 2
        meta1 = poll_for_meta(client, partition_id, [field], 2, task_name, mode="force_fetch")
        # poll_for_meta returns None on timeout; fail with a clear message instead of
        # an AttributeError on the next line.
        assert meta1 is not None, "Failed to poll metadata after first put"
        assert not meta1.field_schema[field]["is_nested"]
        retrieved_1 = client.get_data(meta1)
        assert retrieved_1[field].is_nested
        assert len(retrieved_1[field]) == 2
        assert retrieved_1[field][0].shape == (4,)

        # Phase 2: two more samples of a different shape (6,) in the same field.
        alloc_meta2 = client.get_meta(
            partition_id=partition_id,
            data_fields=[field],
            batch_size=2,
            mode="insert",
            task_name="allocator",
        )
        assert alloc_meta2.size == 2
        data2 = TensorDict({field: torch.ones(2, 6)}, batch_size=2)
        client.put(data=data2, metadata=alloc_meta2)
        meta2 = poll_for_meta(client, partition_id, [field], 2, task_name, mode="force_fetch")
        assert meta2 is not None, "Failed to poll metadata after second put"
        assert meta2.field_schema[field]["is_nested"] is True

        # Phase 3: fetching all four samples must yield a nested tensor.
        meta_all = poll_for_meta(client, partition_id, [field], 4, task_name, mode="force_fetch")
        assert meta_all is not None, "Failed to poll metadata for all samples"
        assert meta_all.field_schema[field]["is_nested"] is True
        retrieved_all = client.get_data(meta_all)
        assert retrieved_all[field].is_nested is True
        assert len(retrieved_all[field]) == 4
    finally:
        client.clear_partition(partition_id)
def test_retrieved_data_writability_and_memory_safety(e2e_client):
    """Verify that all data types retrieved via GET are writable and memory-independent.

    This test validates the ZMQ copy=False GET path (Plan 1):
    - Tensors (f32, i64, bf16, f16): writable after nested tensor creation
    - Nested tensors (jagged, strided): writable after as_nested_tensor
    - Numpy arrays (float64, bytes string): writable after .copy() in _pack_field_values
    - Modifications to retrieved data do not affect stored data (memory independence)

    NOTE(review): after the mutations, only the four dense tensor fields and
    nested_jagged are re-fetched and compared against the original; nested_strided,
    special_val and np_array are mutated but not re-checked.
    """
    client = e2e_client
    partition_id = "test_writability"
    batch_size = 8
    task_name = "writability_task"
    fields = DEFAULT_FIELDS
    indices = list(range(batch_size))
    original_data = generate_complex_data(indices)
    original_meta = client.put(data=original_data, partition_id=partition_id)
    # Remember the order in which samples were written so fetched batches can be realigned.
    global_index_order = original_meta.global_indexes
    try:
        meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
        assert meta is not None and meta.size == batch_size
        # Reorder the fetched metadata so row i corresponds to original_data row i.
        local_index_order = recover_local_index(global_index_order, meta.global_indexes)
        meta = meta.select_samples(local_index_order)
        retrieved = client.get_data(meta)

        # Part 1: every retrieved field accepts an in-place write and reads it back.
        retrieved["tensor_f32"][0, 0] = 99999.0
        assert retrieved["tensor_f32"][0, 0].item() == 99999.0, "tensor_f32 should be writable"
        retrieved["tensor_i64"][0, 0] = 88888
        assert retrieved["tensor_i64"][0, 0].item() == 88888, "tensor_i64 should be writable"
        retrieved["tensor_bf16"][0, 0] = 77.0
        assert retrieved["tensor_bf16"][0, 0].item() == 77.0, "tensor_bf16 should be writable"
        retrieved["tensor_f16"][0, 0] = 66.0
        assert retrieved["tensor_f16"][0, 0].item() == 66.0, "tensor_f16 should be writable"
        # Jagged nested tensor: write through values().
        jagged_vals = retrieved["nested_jagged"].values()
        jagged_vals[0] = 55555.0
        assert jagged_vals[0].item() == 55555.0, "nested_jagged should be writable"
        # Strided nested tensor: write through the first unbound component.
        strided_subs = list(retrieved["nested_strided"].unbind())
        strided_subs[0][0, 0] = 44444.0
        assert strided_subs[0][0, 0].item() == 44444.0, "nested_strided should be writable"
        retrieved["special_val"][0, 2] = 33333.0
        assert retrieved["special_val"][0, 2].item() == 33333.0, "special_val should be writable"
        # np_array is only exercised when it comes back as a torch.Tensor.
        np_arr_retrieved = retrieved["np_array"]
        if isinstance(np_arr_retrieved, torch.Tensor):
            np_arr_retrieved[0, 0] = 22222.0
            assert np_arr_retrieved[0, 0].item() == 22222.0, "np_array (as tensor) should be writable"

        # Part 2: fetch again and confirm the stored data was not affected by the writes above.
        meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
        assert meta2 is not None and meta2.size == batch_size
        local_index_order = recover_local_index(global_index_order, meta2.global_indexes)
        meta2 = meta2.select_samples(local_index_order)
        retrieved2 = client.get_data(meta2)
        for i in range(batch_size):
            assert torch.allclose(retrieved2["tensor_f32"][i], original_data["tensor_f32"][i]), (
                "Modifying retrieved tensor_f32 should not affect stored data"
            )
        for i in range(batch_size):
            assert torch.equal(retrieved2["tensor_i64"][i], original_data["tensor_i64"][i]), (
                "Modifying retrieved tensor_i64 should not affect stored data"
            )
        for i in range(batch_size):
            assert torch.equal(retrieved2["tensor_bf16"][i], original_data["tensor_bf16"][i]), (
                "Modifying retrieved tensor_bf16 should not affect stored data"
            )
        for i in range(batch_size):
            assert torch.equal(retrieved2["tensor_f16"][i], original_data["tensor_f16"][i]), (
                "Modifying retrieved tensor_f16 should not affect stored data"
            )
        assert verify_nested_tensor_equal(retrieved2["nested_jagged"], original_data["nested_jagged"]), (
            "Modifying retrieved nested_jagged should not affect stored data"
        )
    finally:
        client.clear_partition(partition_id)
# Allow running this module directly: execute its tests verbosely and
# propagate pytest's exit code to the shell.
if __name__ == "__main__":
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)