import logging
import pytest
import ray
import torch
from transfer_queue.controller import TransferQueueController
# Configure root logging at INFO level for this test module.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@pytest.fixture(scope="function")
def ray_setup():
    """Provide a freshly initialised Ray runtime for a single test.

    Any Ray runtime that is already initialised is shut down first, then a new
    one is started. After the test, the runtime is shut down again if it is
    still initialised.
    """
    # Start from a clean slate: drop whatever runtime is still around.
    if ray.is_initialized():
        ray.shutdown()

    runtime_env = {"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}}
    ray.init(ignore_reinit_error=True, runtime_env=runtime_env, log_to_driver=True)

    yield

    # Teardown: nothing to do when the runtime is already gone.
    if not ray.is_initialized():
        return
    ray.shutdown()
    logger.info("Ray has been shut down completely after test")
class TestTransferQueueController:
    """Integration tests for the core remote methods of ``TransferQueueController``.

    Each test creates its own controller actor on the Ray runtime provided by
    the ``ray_setup`` fixture and drives it exclusively through remote calls.
    """

    def test_controller_with_single_partition(self, ray_setup):
        """Drive one partition through insert, produce, fetch, force_fetch and clear."""
        gbs = 8
        num_n_samples = 4
        total_samples = gbs * num_n_samples
        expected_global_index = torch.arange(total_samples, dtype=torch.long)

        tq_controller = TransferQueueController.remote()
        partition_id = "train_0"
        data_fields = ["prompt_ids", "attention_mask"]

        # --- insert: allocate metadata for the whole batch ---
        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=total_samples,
                partition_id=partition_id,
                mode="insert",
            )
        )
        assert metadata.global_indexes == list(range(total_samples))
        assert metadata.partition_ids[0] == "train_0"
        assert metadata.production_status is not None and all(metadata.production_status == 0)
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
        assert partition_index_range == list(range(total_samples))
        print("✓ Initial get metadata correct")

        # --- produce: report both fields as written for every sample ---
        field_schema = {
            "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
        }
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_schema=field_schema,
                custom_backend_meta=None,
            )
        )
        assert success
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert partition.production_status is not None
        assert partition.production_status.size(0) == total_samples

        global_index, production_status = ray.get(
            tq_controller.get_production_status.remote(
                partition_id=partition_id,
                data_fields=data_fields,
            )
        )
        assert torch.equal(global_index, expected_global_index)
        expected_production_status = torch.ones(total_samples, len(metadata.field_names), dtype=torch.int8)
        assert torch.equal(production_status, expected_production_status)
        print("✓ Get production status returns correct global_index and production_status")

        # The snapshot may have more field columns allocated than are in use:
        # the used columns must be fully produced, any spare columns untouched.
        assert partition.total_fields_num == len(data_fields)
        assert partition.allocated_fields_num >= partition.total_fields_num
        assert torch.equal(
            sum(partition.production_status[:, : len(data_fields)]),
            torch.Tensor([total_samples, total_samples]),
        )
        if partition.allocated_fields_num > len(data_fields):
            assert torch.equal(
                sum(partition.production_status[:, len(data_fields) :]),
                torch.zeros(partition.allocated_fields_num - len(data_fields)),
            )
        print(f"✓ Updated production status for partition {partition_id}")

        # --- consumption status before anything was fetched ---
        global_index, consumption_status = ray.get(
            tq_controller.get_consumption_status.remote(
                partition_id=partition_id,
                task_name="generate_sequences",
            )
        )
        assert torch.equal(global_index, expected_global_index)
        expected_consumption_status_before = torch.zeros(total_samples, dtype=torch.int8)
        assert torch.equal(consumption_status, expected_consumption_status_before)
        print("✓ Get consumption status returns correct global_index and status (before consumption)")

        # --- fetch: the task consumes every sample of one field ---
        gen_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=["prompt_ids"],
                batch_size=total_samples,
                partition_id=partition_id,
                mode="fetch",
                task_name="generate_sequences",
            )
        )
        assert gen_meta.global_indexes == list(range(total_samples))
        assert gen_meta.partition_ids[0] == "train_0"
        assert gen_meta.field_names == ["prompt_ids"]
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(total_samples))
        print("✓ Get metadata in fetch mode correct")

        global_index, consumption_status = ray.get(
            tq_controller.get_consumption_status.remote(
                partition_id=partition_id,
                task_name="generate_sequences",
            )
        )
        assert torch.equal(global_index, expected_global_index)
        expected_consumption_status_after = torch.ones(total_samples, dtype=torch.int8)
        assert torch.equal(consumption_status, expected_consumption_status_after)
        print("✓ Get consumption status returns correct global_index and status (after consumption)")

        # --- force_fetch: still returns every index although all were consumed ---
        clear_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=gen_meta.field_names,
                partition_id=partition_id,
                mode="force_fetch",
            )
        )
        assert clear_meta.global_indexes == list(range(total_samples))
        assert clear_meta.field_names == gen_meta.field_names
        print("✓ Clear metadata correct")

        # --- clear: the partition and its index range disappear ---
        ray.get(tq_controller.clear_partition.remote(partition_id))
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
        assert partition_index_range == []
        assert partition is None
        print("✓ Clear partition correct")

    def test_controller_reset_consumption(self, ray_setup):
        """Test reset_consumption functionality - allows data to be re-consumed"""
        gbs = 4
        num_n_samples = 2
        total_samples = gbs * num_n_samples
        partition_id = "test_reset_consumption"
        tq_controller = TransferQueueController.remote()
        data_fields = ["prompt_ids", "attention_mask"]

        all_unconsumed = torch.zeros(total_samples, dtype=torch.int8)
        all_consumed = torch.ones(total_samples, dtype=torch.int8)

        def consumption_status_of(task_name):
            """Return the consumption-status tensor reported for *task_name*."""
            _, status = ray.get(
                tq_controller.get_consumption_status.remote(
                    partition_id=partition_id,
                    task_name=task_name,
                )
            )
            return status

        def fetch_all(field_name, task_name):
            """Fetch the whole batch of *field_name* on behalf of *task_name*."""
            return ray.get(
                tq_controller.get_metadata.remote(
                    data_fields=[field_name],
                    batch_size=total_samples,
                    partition_id=partition_id,
                    mode="fetch",
                    task_name=task_name,
                )
            )

        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=total_samples,
                partition_id=partition_id,
                mode="insert",
            )
        )
        assert metadata.global_indexes == list(range(total_samples))

        field_schema = {
            "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
        }
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_schema=field_schema,
            )
        )
        assert success

        assert torch.equal(consumption_status_of("generate_sequences"), all_unconsumed)
        print("✓ Consumption status before fetch is all zeros")

        gen_meta = fetch_all("prompt_ids", "generate_sequences")
        assert gen_meta.global_indexes == list(range(total_samples))
        assert torch.equal(consumption_status_of("generate_sequences"), all_consumed)
        print("✓ Consumption status after fetch is all ones")

        # Resetting a single task makes its samples consumable again.
        ray.get(
            tq_controller.reset_consumption.remote(
                partition_id=partition_id,
                task_name="generate_sequences",
            )
        )
        assert torch.equal(consumption_status_of("generate_sequences"), all_unconsumed)
        print("✓ Consumption status after reset is all zeros")

        # After the reset the same task can fetch the full batch again, and a
        # second task consumes independently of the first.
        gen_meta_2 = fetch_all("prompt_ids", "generate_sequences")
        assert gen_meta_2.global_indexes == list(range(total_samples))
        gen_meta_3 = fetch_all("attention_mask", "another_task")
        assert gen_meta_3.global_indexes == list(range(total_samples))

        consumption_status_task1 = consumption_status_of("generate_sequences")
        consumption_status_task2 = consumption_status_of("another_task")
        assert torch.equal(consumption_status_task1, all_consumed)
        assert torch.equal(consumption_status_task2, all_consumed)
        print("✓ Both tasks consumed successfully")

        # task_name=None is expected to reset every task of the partition.
        ray.get(
            tq_controller.reset_consumption.remote(
                partition_id=partition_id,
                task_name=None,
            )
        )
        consumption_status_task1_reset = consumption_status_of("generate_sequences")
        consumption_status_task2_reset = consumption_status_of("another_task")
        assert torch.equal(consumption_status_task1_reset, all_unconsumed)
        assert torch.equal(consumption_status_task2_reset, all_unconsumed)
        print("✓ Reset all tasks successful - both tasks have zero consumption status")

        ray.get(tq_controller.clear_partition.remote(partition_id))
        print("✓ Reset consumption test completed successfully")

    def test_controller_with_multi_partitions(self, ray_setup):
        """Check index allocation, isolation and index reuse across three partitions."""
        gbs_1 = 8
        num_n_samples_1 = 4
        partition_id_1 = "train_0"
        gbs_2 = 16
        num_n_samples_2 = 1
        partition_id_2 = "val_0"
        gbs_3 = 32
        num_n_samples_3 = 2
        partition_id_3 = "train_1"

        part1_size = gbs_1 * num_n_samples_1
        part2_size = gbs_2 * num_n_samples_2
        part3_size = gbs_3 * num_n_samples_3
        # Partition 2 is allocated right after partition 1.
        expected_part2_indexes = list(range(part1_size, part1_size + part2_size))

        tq_controller = TransferQueueController.remote()
        data_fields = ["prompt_ids", "attention_mask"]
        field_schema = {
            "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
        }

        # --- partition 1: insert, produce, fetch, force_fetch ---
        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=part1_size,
                partition_id=partition_id_1,
                mode="insert",
            )
        )
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id_1,
                global_indexes=metadata.global_indexes,
                field_schema=field_schema,
            )
        )
        assert success

        global_index_1, production_status_1 = ray.get(
            tq_controller.get_production_status.remote(
                partition_id=partition_id_1,
                data_fields=data_fields,
            )
        )
        expected_global_index_1 = torch.arange(part1_size, dtype=torch.long)
        assert torch.equal(global_index_1, expected_global_index_1)
        expected_production_status_1 = torch.ones(part1_size, len(data_fields), dtype=torch.int8)
        assert torch.equal(production_status_1, expected_production_status_1)
        print("✓ Get production status for partition_1 returns correct global_index and status")

        gen_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=["prompt_ids"],
                batch_size=part1_size,
                partition_id=partition_id_1,
                mode="fetch",
                task_name="generate_sequences",
            )
        )
        assert gen_meta

        global_index_1_consumed, consumption_status_1 = ray.get(
            tq_controller.get_consumption_status.remote(
                partition_id=partition_id_1,
                task_name="generate_sequences",
            )
        )
        assert torch.equal(global_index_1_consumed, expected_global_index_1)
        expected_consumption_status_1 = torch.ones(part1_size, dtype=torch.int8)
        assert torch.equal(consumption_status_1, expected_consumption_status_1)
        print("✓ Get consumption status for partition_1 returns correct global_index and status (after fetch)")

        clear_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=gen_meta.field_names,
                partition_id=partition_id_1,
                mode="force_fetch",
            )
        )
        assert clear_meta

        # --- partition 2: insert and produce, never consumed ---
        val_metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=part2_size,
                partition_id=partition_id_2,
                mode="insert",
            )
        )
        assert val_metadata.global_indexes == expected_part2_indexes
        assert val_metadata.partition_ids[0] == "val_0"
        assert val_metadata.production_status is not None and all(val_metadata.production_status == 0)
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
        assert partition_index_range == expected_part2_indexes

        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id_2,
                global_indexes=val_metadata.global_indexes,
                field_schema=field_schema,
            )
        )
        assert success

        global_index_2, production_status_2 = ray.get(
            tq_controller.get_production_status.remote(
                partition_id=partition_id_2,
                data_fields=data_fields,
            )
        )
        expected_global_index_2 = torch.arange(part1_size, part1_size + part2_size, dtype=torch.long)
        assert torch.equal(global_index_2, expected_global_index_2)
        expected_production_status_2 = torch.ones(part2_size, len(data_fields), dtype=torch.int8)
        assert torch.equal(production_status_2, expected_production_status_2)
        print("✓ Get production status for partition_2 returns correct global_index and status")

        global_index_2_consumed, consumption_status_2 = ray.get(
            tq_controller.get_consumption_status.remote(
                partition_id=partition_id_2,
                task_name="generate_sequences",
            )
        )
        assert torch.equal(global_index_2_consumed, expected_global_index_2)
        expected_consumption_status_2 = torch.zeros(part2_size, dtype=torch.int8)
        assert torch.equal(consumption_status_2, expected_consumption_status_2)
        print("✓ Get consumption status for partition_2 returns correct global_index and status (before consumption)")

        # --- clearing partition 1 must leave partition 2 untouched ---
        partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
        assert partition_index_range_1
        ray.get(tq_controller.clear_partition.remote(partition_id_1))
        partition_1_after_clear = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_1))
        partition_index_range_1_after_clear = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
        assert partition_1_after_clear is None
        assert partition_index_range_1_after_clear == []

        partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
        partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
        assert partition_index_range_2 == expected_part2_indexes
        assert torch.all(
            partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
        )
        print("✓ Only clear partition 1 correct")

        # --- partition 3: expected to reuse the indexes freed by partition 1
        # first, then continue with fresh indexes after partition 2's range ---
        fresh_start = part1_size + part2_size
        fresh_count = part3_size - part1_size
        expected_part3_indexes = list(range(part1_size)) + list(range(fresh_start, fresh_start + fresh_count))

        metadata_2 = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=part3_size,
                partition_id=partition_id_3,
                mode="insert",
            )
        )
        assert metadata_2.global_indexes == expected_part3_indexes
        assert metadata_2.partition_ids[0] == "train_1"
        assert metadata_2.production_status is not None and all(metadata_2.production_status == 0)
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
        assert partition_index_range == expected_part3_indexes
        print("✓ Correctly assign partition_3")

    def test_controller_clear_meta(self, ray_setup):
        """Test clear_meta functionality for individual samples"""
        gbs = 4
        num_n_samples = 2
        total_samples = gbs * num_n_samples
        partition_id = "test_clear_meta"
        tq_controller = TransferQueueController.remote()
        data_fields = ["prompt_ids", "attention_mask"]

        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=total_samples,
                partition_id=partition_id,
                mode="insert",
            )
        )
        assert metadata.global_indexes == list(range(total_samples))

        field_schema = {
            "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
        }
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_schema=field_schema,
            )
        )
        assert success

        partition_before = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert partition_before is not None
        assert len(partition_before.global_indexes) == total_samples
        assert set(partition_before.global_indexes) == set(range(total_samples))

        # Remove a non-contiguous subset; only the remaining samples may survive.
        global_indexes_to_clear = [0, 1, 2, 3, 6]
        partition_ids_to_clear = [partition_id] * len(global_indexes_to_clear)
        ray.get(
            tq_controller.clear_meta.remote(
                global_indexes=global_indexes_to_clear,
                partition_ids=partition_ids_to_clear,
            )
        )
        partition_after = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert partition_after is not None
        assert set(partition_after.global_indexes) == {4, 5, 7}
        print("✓ Clear meta correct")
class TestTransferQueueControllerCustomMeta:
    """Integration tests for TransferQueueController custom_meta and custom_backend_meta methods.

    Note: In this codebase:
    - custom_meta: per-sample metadata (simple key-value pairs per sample)
    - custom_backend_meta: per-sample per-field metadata (stored via update_production_status)
    """

    def test_controller_with_custom_meta(self, ray_setup):
        """Test TransferQueueController with custom_backend_meta and custom_meta functionality"""
        batch_size = 3
        partition_id = "custom_meta_test"
        tq_controller = TransferQueueController.remote()
        data_fields = ["prompt_ids", "attention_mask"]
        field_schema = {
            "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
        }

        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=batch_size,
                partition_id=partition_id,
                mode="insert",
            )
        )
        assert metadata.global_indexes == list(range(batch_size))

        # --- custom_backend_meta: keyed by global index, then by field name ---
        custom_backend_meta = {
            0: {"prompt_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}},
            1: {"prompt_ids": {"token_count": 120}, "attention_mask": {"mask_ratio": 0.15}},
            2: {"prompt_ids": {"token_count": 90}, "attention_mask": {"mask_ratio": 0.12}},
        }
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_schema=field_schema,
                custom_backend_meta=custom_backend_meta,
            )
        )
        assert success

        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert partition is not None
        result = partition.get_field_custom_backend_meta(list(range(batch_size)), ["prompt_ids", "attention_mask"])
        assert len(result) == batch_size
        assert result[0]["prompt_ids"]["token_count"] == 100
        assert result[2]["attention_mask"]["mask_ratio"] == 0.12
        print("✓ Controller set custom_backend_meta via update_production_status correct")

        # --- custom_meta: keyed by partition id, then by global index ---
        custom_meta = {
            partition_id: {
                0: {"sample_score": 0.9, "quality": "high"},
                1: {"sample_score": 0.8, "quality": "medium"},
            }
        }
        ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))

        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        result = partition.get_custom_meta([0, 1])
        assert 0 in result
        assert result[0]["sample_score"] == 0.9
        assert result[0]["quality"] == "high"
        assert 1 in result
        assert result[1]["sample_score"] == 0.8
        # Sample 2 was neither requested nor given any custom_meta.
        assert 2 not in result

        # --- second partition, then one call that updates both partitions ---
        new_partition_id = "custom_meta_test2"
        new_metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=batch_size,
                partition_id=new_partition_id,
                mode="insert",
            )
        )
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=new_partition_id,
                global_indexes=new_metadata.global_indexes,
                field_schema=field_schema,
                custom_backend_meta=None,
            )
        )
        assert success

        new_custom_meta = {
            new_partition_id: {
                3: {"sample_score": 1, "quality": "high"},
                4: {"sample_score": 0, "quality": "low"},
            },
            partition_id: {
                2: {"sample_score": 0.7, "quality": "high"},
                0: {"sample_score": 0.001, "quality": "low"},
            },
        }
        ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=new_custom_meta))

        # First partition: sample 0 overwritten, sample 1 untouched, sample 2 added.
        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        result = partition.get_custom_meta([0, 1, 2])
        assert 0 in result
        assert result[0]["sample_score"] == 0.001
        assert result[0]["quality"] == "low"
        assert 1 in result
        assert result[1]["sample_score"] == 0.8
        assert 2 in result
        assert result[2]["sample_score"] == 0.7

        # Second partition: only samples 3 and 4 received custom_meta.
        new_partition = ray.get(tq_controller.get_partition_snapshot.remote(new_partition_id))
        result = new_partition.get_custom_meta([3, 4, 5])
        assert 3 in result
        assert result[3]["sample_score"] == 1
        assert result[3]["quality"] == "high"
        assert 4 in result
        assert result[4]["sample_score"] == 0
        assert 5 not in result

        # Clean up every partition this test created.
        ray.get(tq_controller.clear_partition.remote(partition_id))
        ray.get(tq_controller.clear_partition.remote(new_partition_id))
class TestTransferQueueControllerKvInterface:
    """End-to-end tests for TransferQueueController KV interface functionality.

    Tests for the kv_retrieve_meta and kv_retrieve_keys methods that support
    key-value interface operations across the controller and partition layers.
    """

    def test_controller_kv_retrieve_meta_create_mode(self, ray_setup):
        """Test kv_retrieve_meta with create=True creates new keys in partition."""
        tq_controller = TransferQueueController.remote()
        partition_id = "kv_test_partition"
        keys = ["key_a", "key_b", "key_c"]

        metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))

        partitions = ray.get(tq_controller.list_partitions.remote())
        assert partition_id in partitions
        assert len(metadata.global_indexes) == len(keys)

        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
        assert "key_a" in partition.keys_mapping
        assert "key_b" in partition.keys_mapping
        assert "key_c" in partition.keys_mapping
        # The returned indexes follow the order of the requested keys ...
        assert metadata.global_indexes[0] == partition.keys_mapping["key_a"]
        assert metadata.global_indexes[1] == partition.keys_mapping["key_b"]
        assert metadata.global_indexes[2] == partition.keys_mapping["key_c"]
        # ... and the reverse mapping leads back to the same keys.
        assert partition.revert_keys_mapping[metadata.global_indexes[0]] == "key_a"
        assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b"
        assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c"
        print("✓ kv_retrieve_meta with create=True creates keys correctly")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_meta_existing_keys(self, ray_setup):
        """Test kv_retrieve_meta retrieves existing keys correctly."""
        tq_controller = TransferQueueController.remote()
        partition_id = "kv_existing_test"
        keys = ["existing_key_1", "existing_key_2"]

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))
        retrieved_metadata = ray.get(
            tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False)
        )
        assert len(retrieved_metadata.global_indexes) == len(keys)
        print("✓ kv_retrieve_meta retrieves existing keys correctly")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_meta_non_existent_without_create(self, ray_setup):
        """Test kv_retrieve_meta returns an empty result for non-existent keys without create."""
        tq_controller = TransferQueueController.remote()
        partition_id = "kv_nonexistent_test"

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=["initial_key"], partition_id=partition_id, create=True))
        batch_meta = ray.get(
            tq_controller.kv_retrieve_meta.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False)
        )
        # No exception is expected: an unknown key yields an empty result.
        assert batch_meta.size == 0
        print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent keys without create")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_meta_empty_partition_without_create(self, ray_setup):
        """Test kv_retrieve_meta returns an empty result for a non-existent partition without create."""
        tq_controller = TransferQueueController.remote()
        partition_id = "nonexistent_partition"

        batch_meta = ray.get(
            tq_controller.kv_retrieve_meta.remote(keys=["key_1"], partition_id=partition_id, create=False)
        )
        # No exception is expected: an unknown partition yields an empty result.
        assert batch_meta.size == 0
        print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent partition_id without create")

    def test_controller_kv_retrieve_meta_with_production_status(self, ray_setup):
        """Test kv_retrieve_meta works with production status update."""
        tq_controller = TransferQueueController.remote()
        partition_id = "kv_production_test"
        keys = ["sample_1", "sample_2", "sample_3"]

        metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))
        global_indexes = metadata.global_indexes

        field_schema = {"data": {"dtype": "torch.float32", "shape": (64,)}}
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=global_indexes,
                field_schema=field_schema,
            )
        )
        assert success

        # The schema registered above must be visible through a later retrieve.
        retrieved_metadata = ray.get(
            tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False)
        )
        assert len(retrieved_metadata.global_indexes) == len(keys)
        assert "data" in retrieved_metadata.field_schema
        print("✓ kv_retrieve_meta works with production status")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_meta_with_custom_meta(self, ray_setup):
        """Test kv_retrieve_meta preserves custom_meta through retrieve."""
        tq_controller = TransferQueueController.remote()
        partition_id = "kv_custom_meta_test"
        keys = ["key_1", "key_2"]

        metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))
        first_index, second_index = metadata.global_indexes[0], metadata.global_indexes[1]
        custom_meta = {
            partition_id: {
                first_index: {"score": 0.9, "tag": "A"},
                second_index: {"score": 0.8, "tag": "B"},
            }
        }
        ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))

        retrieved_metadata = ray.get(
            tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False)
        )
        all_custom_meta = retrieved_metadata.get_all_custom_meta()
        assert len(all_custom_meta) == 2
        assert all_custom_meta[0]["score"] == 0.9
        assert all_custom_meta[1]["tag"] == "B"
        print("✓ kv_retrieve_meta preserves custom_meta")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_interface_multiple_partitions(self, ray_setup):
        """Test KV interface works correctly across multiple partitions."""
        tq_controller = TransferQueueController.remote()

        partition_1 = "partition_kv_1"
        keys_1 = ["p1_key_a", "p1_key_b"]
        ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys_1, partition_id=partition_1, create=True))

        partition_2 = "partition_kv_2"
        keys_2 = ["p2_key_x", "p2_key_y", "p2_key_z"]
        ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys_2, partition_id=partition_2, create=True))

        partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1))
        partition_2_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_2))

        # Each partition knows its own keys ...
        assert "p1_key_a" in partition_1_snapshot.keys_mapping
        assert "p1_key_b" in partition_1_snapshot.keys_mapping
        assert "p2_key_x" in partition_2_snapshot.keys_mapping
        assert "p2_key_z" in partition_2_snapshot.keys_mapping
        # ... and none of the other partition's keys.
        assert "p2_key_x" not in partition_1_snapshot.keys_mapping
        assert "p1_key_a" not in partition_2_snapshot.keys_mapping
        print("✓ KV interface maintains partition isolation")

        ray.get(tq_controller.clear_partition.remote(partition_1))
        ray.get(tq_controller.clear_partition.remote(partition_2))

    def test_controller_kv_retrieve_keys_basic(self, ray_setup):
        """Test kv_retrieve_keys retrieves keys from global_indexes."""
        tq_controller = TransferQueueController.remote()
        partition_id = "partition_retrieve_idx"
        keys = ["test_key_a", "test_key_b", "test_key_c"]

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))
        # Relies on a fresh controller giving the created keys indexes 0, 1, 2 in order.
        retrieved_keys = ray.get(
            tq_controller.kv_retrieve_keys.remote(global_indexes=[0, 1, 2], partition_id=partition_id)
        )
        assert retrieved_keys == ["test_key_a", "test_key_b", "test_key_c"]
        print("✓ kv_retrieve_keys retrieves keys correctly")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_keys_partial(self, ray_setup):
        """Test kv_retrieve_keys retrieves subset of keys."""
        tq_controller = TransferQueueController.remote()
        partition_id = "partition_retrieve_partial"
        keys = ["key_0", "key_1", "key_2", "key_3", "key_4"]

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True))
        # Only the first and the last of the five created keys are requested.
        retrieved_keys = ray.get(
            tq_controller.kv_retrieve_keys.remote(global_indexes=[0, 4], partition_id=partition_id)
        )
        assert retrieved_keys == ["key_0", "key_4"]
        print("✓ kv_retrieve_keys retrieves subset correctly")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_keys_single_int(self, ray_setup):
        """Test kv_retrieve_keys with list containing single element."""
        tq_controller = TransferQueueController.remote()
        partition_id = "partition_single_int"

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=["single_key"], partition_id=partition_id, create=True))
        retrieved_keys = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[0], partition_id=partition_id))
        assert retrieved_keys == ["single_key"]
        print("✓ kv_retrieve_keys works with list containing single element")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_keys_nonexistent(self, ray_setup):
        """Test kv_retrieve_keys handles non-existent global_indexes."""
        tq_controller = TransferQueueController.remote()
        partition_id = "partition_nonexistent"

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=["existing_key"], partition_id=partition_id, create=True))
        result = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[99], partition_id=partition_id))
        # An index without a key is reported as None rather than dropped.
        assert result == [None]
        print("✓ kv_retrieve_keys handles non-existent indexes")

        ray.get(tq_controller.clear_partition.remote(partition_id))

    def test_controller_kv_retrieve_keys_multiple_partitions(self, ray_setup):
        """Test kv_retrieve_keys respects partition isolation."""
        tq_controller = TransferQueueController.remote()
        partition_1 = "partition_idx_1"
        partition_2 = "partition_idx_2"

        ray.get(tq_controller.kv_retrieve_meta.remote(keys=["p1_key"], partition_id=partition_1, create=True))
        ray.get(tq_controller.kv_retrieve_meta.remote(keys=["p2_key"], partition_id=partition_2, create=True))

        # Relies on indexes being handed out in creation order: 0 for the first
        # partition's key, 1 for the second partition's key.
        keys_1 = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[0], partition_id=partition_1))
        keys_2 = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[1], partition_id=partition_2))
        assert keys_1 == ["p1_key"]
        assert keys_2 == ["p2_key"]
        print("✓ kv_retrieve_keys maintains partition isolation")

        ray.get(tq_controller.clear_partition.remote(partition_1))
        ray.get(tq_controller.clear_partition.remote(partition_2))