import logging
import os
import time
# Configure root logging at INFO so messages logged while these tests run are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Expected initial allocated sample / field capacity of a freshly constructed
# DataPartitionStatus (compared against allocated_samples_num and
# allocated_fields_num below). Overridable through the environment; default 1.
# NOTE(review): presumably meant to mirror how the controller itself is
# configured via the same variables — confirm.
TQ_INIT_SAMPLE_NUM = int(os.getenv("TQ_INIT_SAMPLE_NUM", 1))
TQ_INIT_FIELD_NUM = int(os.getenv("TQ_INIT_FIELD_NUM", 1))
def test_data_partition_status():
    """Test the DataPartitionStatus class functionality.

    Walks one partition through its basic lifecycle: initial capacity, dynamic
    expansion on the first production update, field metadata lookup,
    consumption marking, ready-sample scanning and statistics.
    """
    print("Testing DataPartitionStatus...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="test@partition_1")

    # Fresh partition: no samples or fields yet, only the initial capacity.
    assert partition.total_samples_num == 0
    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
    assert partition.total_fields_num == 0
    assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM
    assert partition.production_status is not None
    print("✓ Initial state correct")

    success = partition.update_production_status(
        global_indices=[0, 1, 2],
        field_names=["input_ids", "attention_mask"],
        field_schema={
            "input_ids": {"dtype": "torch.int32", "shape": (512,), "is_nested": False, "is_non_tensor": False},
            "attention_mask": {"dtype": "torch.bool", "shape": (512,), "is_nested": False, "is_non_tensor": False},
        },
        custom_backend_meta=None,
    )
    assert success
    # The production matrix must have grown to cover at least 3 samples x 2 fields.
    assert partition.total_samples_num >= 3
    assert partition.total_fields_num == 2
    assert partition.production_status is not None
    assert partition.production_status.shape[0] >= 3
    assert partition.production_status.shape[1] >= 2
    print("✓ Dynamic expansion works")

    assert "input_ids" in partition.field_metadata
    assert partition.field_metadata["input_ids"].dtype == "torch.int32"
    assert "attention_mask" in partition.field_metadata
    assert partition.field_metadata["attention_mask"].shape == (512,)
    print("✓ Field metadata retrieval works")

    # Unmasked status covers the whole allocated sample capacity.
    _, consumption_tensor = partition.get_consumption_status("test_task", mask=False)
    assert consumption_tensor is not None
    assert consumption_tensor.shape[0] == partition.allocated_samples_num
    print("✓ Consumption status creation works")

    partition.mark_consumed("test_task", [0, 1])
    # Re-read the status instead of relying on the tensor fetched above being a
    # live view of internal state: a correct implementation may return a copy,
    # or reallocate the tensor while marking, which would fail spuriously here.
    _, consumption_tensor = partition.get_consumption_status("test_task", mask=False)
    assert consumption_tensor[0] == 1
    assert consumption_tensor[1] == 1
    assert consumption_tensor[2] == 0
    print("✓ Sample consumption marking works")

    # Samples 0 and 1 are consumed for this task, so only sample 2 is ready.
    ready_samples = partition.scan_data_status(field_names=["input_ids", "attention_mask"], task_name="test_task")
    assert len(ready_samples) == 1, f"Expected 1 ready sample, got {len(ready_samples)}: {ready_samples}"
    assert ready_samples == [2], f"Expected [2], got {ready_samples}"
    print("✓ Ready sample scanning works")

    stats = partition.get_statistics()
    assert stats["partition_id"] == "test@partition_1"
    assert stats["total_samples_num"] == partition.total_samples_num
    assert stats["total_fields_num"] == 2
    assert "consumption_statistics" in stats
    print("✓ Statistics generation works")

    print("DataPartitionStatus tests passed!\n")
def test_partition_interface():
    """Test the partition interface design.

    TransferQueueController must expose the partition-level API, and
    ``create_partition`` must take a ``partition_id`` without any up-front
    ``num_samples`` argument.
    """
    print("Testing partition interface design...")

    from transfer_queue.controller import TransferQueueController

    expected_methods = (
        "create_partition",
        "get_partition_snapshot",
        "update_production_status",
        "scan_data_status",
        "generate_batch_meta",
    )
    for method_name in expected_methods:
        assert hasattr(TransferQueueController, method_name)
    print("✓ Controller has all expected methods")

    import inspect

    create_params = inspect.signature(TransferQueueController.create_partition).parameters
    assert "partition_id" in create_params
    # Sample capacity grows on demand, so no sample count is passed at creation.
    assert "num_samples" not in create_params
    print("✓ Method signatures are correct")

    print("Partition interface tests passed!\n")
def test_dynamic_expansion_scenarios():
    """Test various dynamic expansion scenarios.

    Covers sparse sample indices, growing the field dimension one field at a
    time, and several independent consumer tasks on one partition.
    """
    print("Testing dynamic expansion scenarios...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="expansion_test")

    # Sparse indices: capacity must reach the largest index (10), while only
    # three samples are actually registered.
    partition.update_production_status(
        global_indices=[0, 5, 10],
        field_names=["field_1"],
        field_schema={"field_1": {"dtype": "torch.bool", "shape": (32,)}},
        custom_backend_meta=None,
    )
    assert partition.total_samples_num == 3
    assert partition.allocated_samples_num >= 11
    print("✓ Large index gaps handled correctly")

    # Registers field_0 .. field_14; "field_1" already exists, so 15 in total.
    for field_idx in range(15):
        name = f"field_{field_idx}"
        partition.update_production_status(
            [0],
            [name],
            field_schema={name: {"dtype": "torch.bool", "shape": (32,)}},
        )
    assert partition.total_fields_num == 15
    assert partition.allocated_fields_num >= 15
    print("✓ Dynamic field expansion works")

    task_names = ["task1", "task2", "task3"]
    for task in task_names:
        partition.get_consumption_status(task)
        partition.mark_consumed(task, [0, 1])

    assert len(partition.consumption_status) == 3
    for task in task_names:
        task_status = partition.consumption_status[task]
        assert task_status[0] == 1
        assert task_status[1] == 1
    print("✓ Multiple task consumption works")

    print("Dynamic expansion tests passed!\n")
def test_data_partition_status_advanced():
    """Advanced tests for DataPartitionStatus refactoring features.

    Checks capacity bookkeeping as samples and fields are added over several
    rounds, that per-task consumption flags survive later production updates,
    that statistics expose every expected key, and that clear_data can keep
    consumption state intact.
    """
    print("Testing advanced DataPartitionStatus features...")

    from transfer_queue.controller import DataPartitionStatus

    def bool_schema(*names):
        # One boolean (32,) schema entry per requested field name.
        return {name: {"dtype": "torch.bool", "shape": (32,)} for name in names}

    partition = DataPartitionStatus(partition_id="advanced_test")
    assert partition.total_samples_num == 0
    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
    assert partition.total_fields_num == 0
    assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM

    letter_fields = [f"dynamic_field_{letter}" for letter in "abc"]
    partition.update_production_status(
        [0, 1, 2, 3, 4],
        letter_fields,
        field_schema=bool_schema(*letter_fields),
    )
    assert partition.total_samples_num >= 5
    assert partition.total_fields_num == 3
    assert partition.allocated_fields_num >= 3
    print("✓ Property-based capacity tracking works")

    # Marks are applied before this task's status is queried for the first time.
    task = "multi_expansion_task"
    partition.mark_consumed(task, [0, 1])
    _, before_growth = partition.get_consumption_status(task)
    assert before_growth[0] == 1
    assert before_growth[1] == 1

    # Producing indices up to 12 may grow the sample dimension; earlier flags must survive.
    partition.update_production_status(
        [10, 11, 12],
        ["field_d"],
        field_schema=bool_schema("field_d"),
    )
    _, after_growth = partition.get_consumption_status(task)
    assert after_growth[0] == 1
    assert after_growth[1] == 1
    assert after_growth.shape[0] >= 13
    print("✓ Consumption data preserved across expansions")

    partition.update_production_status(
        [0],
        ["initial_field"],
        field_schema=bool_schema("initial_field"),
    )
    numbered_fields = [f"dynamic_field_{i}" for i in range(20)]
    partition.update_production_status([1], numbered_fields, field_schema=bool_schema(*numbered_fields))

    assert "initial_field" in partition.field_name_mapping
    for name in numbered_fields:
        assert name in partition.field_name_mapping
    minimum_fields = 1 + len(numbered_fields)
    assert partition.total_fields_num >= minimum_fields
    assert partition.allocated_fields_num >= partition.total_fields_num
    print("✓ Complex field addition scenarios work")

    stats = partition.get_statistics()
    required_keys = (
        "partition_id",
        "created_at",
        "total_samples_num",
        "total_fields_num",
        "allocated_samples_num",
        "allocated_fields_num",
        "registered_tasks",
        "produced_samples",
        "production_progress",
        "field_statistics",
        "consumption_statistics",
    )
    for key in required_keys:
        assert key in stats, f"Missing key in statistics: {key}"
    assert stats["partition_id"] == "advanced_test"
    assert stats["total_fields_num"] > 0
    assert isinstance(stats["field_statistics"], dict)
    assert isinstance(stats["consumption_statistics"], dict)
    print("✓ Statistics generation comprehensive")

    def total_consumed():
        # Sum of every consumption flag across all registered tasks.
        return sum(flags.sum().item() for flags in partition.consumption_status.values())

    consumed_before = total_consumed()
    partition.clear_data(list(range(4)), clear_consumption=False)
    assert partition.production_status[:4, :].sum().item() == 0
    consumed_after = total_consumed()
    assert consumed_after == consumed_before
    print("✓ Selective data clearing works")

    print("Advanced DataPartitionStatus tests passed!\n")
def test_edge_cases_and_error_handling():
    """Test edge cases and error handling in DataPartitionStatus.

    Uses an empty partition to check: scanning and metadata lookup for unknown
    fields, registering a task before any data exists, an update with no
    indices/fields, and the first real update.
    """
    print("Testing edge cases and error handling...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="edge_test")

    # Nothing has been produced, so nothing can be ready.
    ready = partition.scan_data_status(["nonexistent_field"], "task")
    assert ready == []
    print("✓ Empty partition operations handled gracefully")

    assert "nonexistent_field" not in partition.field_metadata
    print("✓ Metadata retrieval for non-existent data handled correctly")

    # A task may be queried before any sample exists.
    early_task = "early_task"
    _, early_status = partition.get_consumption_status(early_task)
    assert early_status is not None
    assert early_status.shape[0] == partition.allocated_samples_num

    empty_update_ok = partition.update_production_status([], [], {}, {})
    assert empty_update_ok

    first_update_ok = partition.update_production_status(
        [0],
        ["new_field"],
        field_schema={"new_field": {"dtype": "torch.int64", "shape": (32,)}},
    )
    assert first_update_ok
    print("✓ Production status update edge cases handled correctly")

    print("Edge cases and error handling tests passed!\n")
def test_performance_characteristics():
    """Test performance characteristics of the refactored implementation.

    Loose wall-clock budgets catch gross slowdowns when registering many
    fields, registering many samples, and scanning for ready samples.
    Durations are measured with ``time.perf_counter`` (monotonic, high
    resolution) rather than ``time.time``, which follows the system clock and
    can jump while a measurement is in progress.
    """
    print("Testing performance characteristics...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="perf_test")

    field_count = 100
    many_fields = [f"perf_field_{i}" for i in range(field_count)]
    field_schema = {name: {"dtype": "torch.bool", "shape": (32,)} for name in many_fields}
    # Time only the call under test, not the construction of its inputs.
    start_time = time.perf_counter()
    partition.update_production_status([0], many_fields, field_schema=field_schema)
    field_creation_time = time.perf_counter() - start_time
    assert partition.total_fields_num == field_count
    assert field_creation_time < 5.0
    print(f"✓ Large field creation: {field_creation_time:.3f}s for {field_count} fields")

    sample_count = 5000
    many_samples = list(range(sample_count))
    field_schema = {"test_field": {"dtype": "torch.int64", "shape": (32,)}}
    start_time = time.perf_counter()
    partition.update_production_status(many_samples, ["test_field"], field_schema=field_schema)
    sample_creation_time = time.perf_counter() - start_time
    assert partition.total_samples_num >= sample_count
    assert sample_creation_time < 5.0
    print(f"✓ Large sample creation: {sample_creation_time:.3f}s for {sample_count} samples")

    # Consume every even sample; the odd half should still be reported ready.
    task_name = "perf_task"
    partition.mark_consumed(task_name, many_samples[::2])
    start_time = time.perf_counter()
    ready_samples = partition.scan_data_status(["test_field"], task_name)
    scanning_time = time.perf_counter() - start_time
    assert len(ready_samples) == sample_count // 2
    assert scanning_time < 1.0
    print(f"✓ Efficient scanning: {scanning_time:.3f}s for {sample_count} samples")

    # Sample 100 already exists, so this update only adds one field; that
    # should not require doubling the allocated field capacity.
    initial_allocated = partition.allocated_fields_num
    initial_samples = partition.total_samples_num
    field_schema = {"new_field": {"dtype": "torch.int64", "shape": (32,)}}
    partition.update_production_status([100], ["new_field"], field_schema=field_schema)
    final_allocated = partition.allocated_fields_num
    final_samples = partition.total_samples_num
    if final_samples == initial_samples:
        assert final_allocated < initial_allocated * 2
    print("✓ Memory usage patterns reasonable")

    print("Performance characteristics tests passed!\n")
def test_custom_meta_in_data_partition_status():
    """Simple tests for custom_meta and custom_backend_meta functionality in DataPartitionStatus.

    Per-field backend metadata is supplied with the production update, while
    per-sample custom_meta is set separately. Clearing a sample drops both
    kinds of metadata for that sample only.
    """
    print("Testing custom_meta and custom_backend_meta in DataPartitionStatus...")

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="custom_meta_test")

    backend_meta_by_sample = {
        0: {"input_ids": {"token_count": 100}},
        1: {"attention_mask": {"mask_ratio": 0.2}},
        2: {"input_ids": {"token_count": 300}},
    }
    success = partition.update_production_status(
        global_indices=[0, 1, 2],
        field_names=["input_ids", "attention_mask"],
        field_schema={
            "input_ids": {"dtype": "torch.int32", "shape": (512,)},
            "attention_mask": {"dtype": "torch.bool", "shape": (512,)},
        },
        custom_backend_meta=backend_meta_by_sample,
    )
    assert success

    backend = partition.get_field_custom_backend_meta([0, 1, 2], ["input_ids", "attention_mask"])
    assert 0 in backend
    assert backend[0]["input_ids"]["token_count"] == 100
    assert 1 in backend
    assert backend[1]["attention_mask"]["mask_ratio"] == 0.2

    partition.set_custom_meta({0: {"sample_score": 0.9}, 1: {"sample_score": 0.8}})
    sample_meta = partition.get_custom_meta([0, 1])
    assert 0 in sample_meta
    assert sample_meta[0]["sample_score"] == 0.9
    assert 1 in sample_meta
    assert sample_meta[1]["sample_score"] == 0.8

    # Clearing sample 0 must remove its metadata while leaving the others.
    partition.clear_data([0], clear_consumption=True)

    sample_meta_after = partition.get_custom_meta([0, 1])
    assert 0 not in sample_meta_after
    assert 1 in sample_meta_after

    backend_after = partition.get_field_custom_backend_meta([0, 1, 2], ["input_ids", "attention_mask"])
    assert 0 not in backend_after
    assert 2 in backend_after
    print("✓ Custom_meta and custom_backend_meta tests passed")
class TestUpdateFieldMetadata:
    """Unit tests for _update_field_metadata with columnar field_schema.

    Each call passes the global indexes being written plus one schema dict per
    field name covering all of them (nested fields may also carry
    ``per_sample_shapes`` keyed by global index). Covers creation, dtype and
    shape conflicts, nested shapes, custom backend metadata and empty input.
    """

    def _make_partition(self):
        """Return a fresh, empty DataPartitionStatus for one test."""
        from transfer_queue.controller import DataPartitionStatus

        return DataPartitionStatus(partition_id="update_meta_test")

    def test_basic_write_and_incremental_add(self):
        part = self._make_partition()

        part._update_field_metadata([0, 1], {"f1": {"dtype": "torch.int32", "shape": (16,)}})
        f1_meta = part.field_metadata["f1"]
        assert f1_meta.dtype == "torch.int32"
        assert f1_meta.shape == (16,)

        # A later call can register an additional field.
        part._update_field_metadata([2], {"f2": {"dtype": "torch.float32", "shape": (256,)}})
        assert part.field_metadata["f2"].dtype == "torch.float32"

    def test_dtype_conflict_raises_error(self):
        import pytest

        part = self._make_partition()
        part._update_field_metadata([0], {"f1": {"dtype": "torch.int32", "shape": (16,)}})
        # Writing an existing field with a different dtype is rejected.
        with pytest.raises(ValueError, match="dtype mismatch"):
            part._update_field_metadata([1], {"f1": {"dtype": "torch.float64", "shape": (16,)}})

    def test_shape_conflict_promotes_to_nested(self):
        part = self._make_partition()
        part._update_field_metadata([0], {"f2": {"dtype": "torch.float32", "shape": (256,)}})
        # Same dtype but a different shape: the field becomes nested with no fixed shape.
        part._update_field_metadata([1], {"f2": {"dtype": "torch.float32", "shape": (128,)}})
        f2_meta = part.field_metadata["f2"]
        assert f2_meta.is_nested is True
        assert f2_meta.shape is None

    def test_nested_per_sample_shapes(self):
        part = self._make_partition()
        nested_schema = {
            "f3": {
                "dtype": "torch.float32",
                "shape": None,
                "is_nested": True,
                "per_sample_shapes": {10: (3,), 11: (5,)},
            }
        }
        part._update_field_metadata([10, 11], nested_schema)
        f3_meta = part.field_metadata["f3"]
        assert f3_meta.is_nested is True
        assert f3_meta.per_sample_shapes == {10: (3,), 11: (5,)}

    def test_custom_backend_meta(self):
        part = self._make_partition()
        part._update_field_metadata(
            [2],
            {"f1": {"dtype": "torch.int32"}},
            custom_backend_meta={2: {"f1": {"k": 1}}},
        )
        assert part.field_custom_backend_meta[2]["f1"]["k"] == 1

    def test_empty_global_indexes_is_noop(self):
        part = self._make_partition()
        part._update_field_metadata([], {}, custom_backend_meta=None)
        assert part.field_metadata == {}
def test_get_production_status_for_fields():
    """Test get_production_status_for_fields method with mask parameter.

    ``mask=False`` returns one row per allocated sample slot; ``mask=True``
    returns only the produced rows. The column count follows the requested
    field list.
    """
    print("Testing get_production_status_for_fields...")

    import torch

    from transfer_queue.controller import DataPartitionStatus

    produced = [0, 1, 2, 3, 9]
    expected_index = torch.tensor(produced, dtype=torch.long)

    partition = DataPartitionStatus(partition_id="production_status_test")
    partition.update_production_status(
        global_indices=[0, 1, 2, 3, 9],
        field_names=["field_a", "field_b"],
        field_schema={
            "field_a": {"dtype": "torch.int64", "shape": (32,)},
            "field_b": {"dtype": "torch.bool", "shape": (32,)},
        },
    )

    index_full, status_full = partition.get_production_status_for_fields(
        field_names=["field_a", "field_b"], mask=False
    )
    assert torch.equal(index_full, expected_index)
    assert status_full.shape[0] == partition.allocated_samples_num
    for idx in produced:
        assert torch.all(status_full[idx] == 1), f"Sample {idx} should be produced"
    assert status_full.shape[1] == 2
    print("✓ get_production_status_for_fields without mask works")

    index_masked, status_masked = partition.get_production_status_for_fields(
        field_names=["field_a", "field_b"], mask=True
    )
    assert torch.equal(index_masked, expected_index)
    assert status_masked.shape == (len(produced), 2)
    assert torch.all(status_masked == 1)
    print("✓ get_production_status_for_fields with mask works")

    index_subset, status_subset = partition.get_production_status_for_fields(
        field_names=["field_a"], mask=True
    )
    assert index_subset.shape[0] == len(produced)
    assert status_subset.shape == (len(produced), 1)
    print("✓ get_production_status_for_fields with subset fields works")

    print("get_production_status_for_fields tests passed!\n")
def test_get_consumption_status_parameter():
    """Test get_consumption_status method with mask parameter.

    ``mask=False`` returns one consumption flag per allocated sample slot;
    ``mask=True`` returns only the rows of produced samples, aligned with the
    returned global index. A second partition is populated to check that
    consumption marked in one partition does not show up in the other.
    """
    print("Testing consumption status mask parameter...")

    import torch

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="consumption_mask_test")
    partition_another = DataPartitionStatus(partition_id="other_partition")
    partition.update_production_status(
        global_indices=[0, 1, 2, 3, 9],
        field_names=["field_a"],
        field_schema={"field_a": {"dtype": "torch.int64", "shape": (32,)}},
    )
    partition_another.update_production_status(
        global_indices=[5, 6, 7],
        field_names=["field_a"],
        field_schema={"field_a": {"dtype": "torch.int64", "shape": (32,)}},
    )

    partition.mark_consumed("test_task", [0, 2])
    expected_index = torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)

    global_index, consumption_status = partition.get_consumption_status("test_task", mask=False)
    assert global_index.shape[0] == partition.total_samples_num
    assert torch.equal(global_index, expected_index)
    # The unmasked tensor spans the allocated capacity. That capacity must reach
    # index 9, but it may exceed 10 (e.g. when TQ_INIT_SAMPLE_NUM is raised), so
    # compare against the partition rather than a hard-coded length.
    assert consumption_status.shape[0] == partition.allocated_samples_num
    assert consumption_status.shape[0] >= 10
    # Only samples 0 and 2 were consumed; the unproduced gap 4-8 stays 0 too.
    assert consumption_status[:10].tolist() == [1, 0, 1, 0, 0, 0, 0, 0, 0, 0]
    print("✓ get_consumption_status without mask works")

    global_index_masked, consumption_status_masked = partition.get_consumption_status("test_task", mask=True)
    assert global_index_masked.shape[0] == partition.total_samples_num
    assert torch.equal(global_index_masked, expected_index)
    assert consumption_status_masked.shape[0] == partition.total_samples_num
    # Masked rows line up with expected_index, so position 4 is global index 9.
    assert consumption_status_masked.tolist() == [1, 0, 1, 0, 0]
    print("✓ get_consumption_status with mask works")

    # Consumption is tracked per partition: nothing was consumed in partition_another.
    _, other_consumption = partition_another.get_consumption_status("test_task", mask=True)
    assert torch.all(other_consumption == 0)

    print("Consumption status mask parameter tests passed!\n")
def test_pre_allocated_indexes_basic():
    """Test basic pre-allocated indexes functionality in DataPartitionStatus.

    Registered pre-allocated indexes do not count as samples until they are
    activated. Activation returns at most the requested number, moves them
    into ``global_indexes``, and removes them from the pre-allocated pool.
    """
    from transfer_queue.controller import DataPartitionStatus

    print("Testing pre-allocated indexes basic functionality...")
    partition = DataPartitionStatus(partition_id="prealloc_test")
    assert len(partition.pre_allocated_global_indexes) == 0
    assert partition.total_samples_num == 0
    print("✓ Initial state correct")

    reserved = [0, 1, 2, 3, 4]
    partition.register_pre_allocated_indexes(reserved)
    assert partition.pre_allocated_global_indexes == set(reserved)
    # Reserving indexes alone does not create samples.
    assert partition.total_samples_num == 0
    print("✓ Pre-allocated indexes registered")

    first_batch = partition.activate_pre_allocated_indexes(3)
    assert len(first_batch) == 3
    assert set(first_batch) == {0, 1, 2}
    assert partition.global_indexes == {0, 1, 2}
    assert partition.pre_allocated_global_indexes == {3, 4}
    assert partition.total_samples_num == 3
    print("✓ Pre-allocated indexes activate & retrieved correctly")

    # Requesting more than remain hands back only the leftover indexes.
    second_batch = partition.activate_pre_allocated_indexes(5)
    assert len(second_batch) == 2
    assert set(second_batch) == {3, 4}
    assert partition.global_indexes == {0, 1, 2, 3, 4}
    assert partition.pre_allocated_global_indexes == set()
    assert partition.total_samples_num == 5
    print("✓ All pre-allocated indexes retrieved")

    print("Pre-allocated indexes basic tests passed!\n")
def test_pre_allocated_indexes_consumption_status():
    """Test that pre-allocated indexes are included in consumption status.

    Only indexes are reserved here, with no production; the masked consumption
    view must still list them and must reflect mark_consumed on them.
    """
    import torch

    from transfer_queue.controller import DataPartitionStatus

    print("Testing pre-allocated indexes in consumption status...")
    partition = DataPartitionStatus(partition_id="consumption_test")
    partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])

    index_view, consumed = partition.get_consumption_status("test_task", mask=True)
    assert torch.equal(index_view, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long))
    assert torch.all(consumed == 0)
    print("✓ Consumption status includes pre-allocated indexes")

    partition.mark_consumed("test_task", [0, 2, 4])
    index_view, consumed = partition.get_consumption_status("test_task", mask=True)
    for position, expected_flag in enumerate([1, 0, 1, 0, 1]):
        assert consumed[position].item() == expected_flag
    print("✓ Marked consumed works with pre-allocated indexes")

    print("Pre-allocated indexes consumption status tests passed!\n")
def test_pre_allocated_indexes_in_scan_data_status():
    """Test that pre-allocated indexes affect scan_data_status behavior.

    Reserved but unproduced indexes are never reported ready. Once data is
    produced, only produced and not-yet-consumed indexes are returned.
    """
    from transfer_queue.controller import DataPartitionStatus

    print("Testing pre-allocated indexes in scan_data_status...")
    partition = DataPartitionStatus(partition_id="scan_test")
    partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])

    def ready_set():
        # Ready indexes for "input_ids" from the point of view of "test_task".
        return partition.scan_data_status(field_names=["input_ids"], task_name="test_task")

    before_production = ready_set()
    assert before_production == []
    print("✓ Scan returns empty before production")

    partition.update_production_status(
        global_indices=[0, 2, 4],
        field_names=["input_ids"],
        field_schema={"input_ids": {"dtype": "torch.int32", "shape": (32,)}},
    )
    after_production = ready_set()
    assert set(after_production) == {0, 2, 4}
    print("✓ Scan returns produced samples correctly")

    partition.mark_consumed("test_task", [2])
    after_consumption = ready_set()
    assert set(after_consumption) == {0, 4}
    print("✓ Scan respects consumption status")

    print("Pre-allocated indexes scan_data_status tests passed!\n")
def test_pre_allocated_indexes_mixed_with_dynamic():
    """Test mixing pre-allocated indexes with dynamically allocated ones.

    Producing data for indexes that were never reserved leaves the reservation
    untouched. Activating the reservation afterwards adds its indexes to the
    live set alongside the dynamic ones.
    """
    from transfer_queue.controller import DataPartitionStatus

    print("Testing mixed pre-allocated and dynamic indexes...")
    partition = DataPartitionStatus(partition_id="mixed_test")
    partition.register_pre_allocated_indexes([0, 1, 2])

    # 5-7 were never reserved; production makes them live directly.
    partition.update_production_status(
        global_indices=[5, 6, 7],
        field_names=["input_ids"],
        field_schema={"input_ids": {"dtype": "torch.int32", "shape": (32,)}},
    )
    assert partition.global_indexes == {5, 6, 7}
    assert partition.total_samples_num == 3

    activated = partition.activate_pre_allocated_indexes(3)
    assert set(activated) == {0, 1, 2}
    assert partition.global_indexes == {0, 1, 2, 5, 6, 7}
    assert partition.total_samples_num == 6
    print("✓ Mixed pre-allocated and dynamic indexes work correctly")

    print("Mixed indexes tests passed!\n")
class TestDataPartitionStatusCustomMeta:
    """Unit tests for DataPartitionStatus custom_meta methods.

    Covers set_custom_meta (insert and overwrite), get_custom_meta (returns
    only requested indexes that have metadata), and removal through clear_data.
    """

    def _make_partition(self):
        """Return a fresh DataPartitionStatus shared by these tests."""
        from transfer_queue.controller import DataPartitionStatus

        return DataPartitionStatus(partition_id="train_0")

    def test_set_custom_meta_single_partition(self):
        """Test set_custom_meta sets custom metadata for samples in a partition."""
        part = self._make_partition()
        part.set_custom_meta(
            {
                0: {"score": 0.9, "label": "positive"},
                1: {"score": 0.8, "label": "negative"},
            }
        )

        fetched = part.get_custom_meta([0, 1, 2])
        assert 0 in fetched
        assert fetched[0]["score"] == 0.9
        assert 1 in fetched
        assert fetched[1]["label"] == "negative"

    def test_set_custom_meta_updates_existing(self):
        """Test set_custom_meta updates existing custom metadata."""
        part = self._make_partition()
        part.set_custom_meta({0: {"score": 0.5}})
        part.set_custom_meta({0: {"score": 0.9, "label": "updated"}})

        sample_zero = part.get_custom_meta([0])[0]
        assert sample_zero["score"] == 0.9
        assert sample_zero["label"] == "updated"

    def test_get_custom_meta_returns_only_requested(self):
        """Test get_custom_meta only returns metadata for requested indices."""
        part = self._make_partition()
        part.set_custom_meta({idx: {"data": f"sample_{idx}"} for idx in (0, 1, 2)})

        fetched = part.get_custom_meta([0, 2])
        assert 0 in fetched
        assert 2 in fetched
        assert 1 not in fetched
        assert fetched[0]["data"] == "sample_0"
        assert fetched[2]["data"] == "sample_2"

    def test_get_custom_meta_empty_for_missing(self):
        """Test get_custom_meta returns empty dict for indices without metadata."""
        part = self._make_partition()
        part.set_custom_meta({0: {"score": 0.9}})

        fetched = part.get_custom_meta([1, 2])
        for idx in (0, 1, 2):
            assert idx not in fetched

    def test_custom_meta_cleared_with_data(self):
        """Test custom_meta is cleared when clearing sample data."""
        part = self._make_partition()
        part.update_production_status(
            global_indices=[0, 1],
            field_names=["input_ids"],
            field_schema={"input_ids": {"dtype": "torch.int32", "shape": (512,)}},
        )
        part.set_custom_meta({0: {"score": 0.9}, 1: {"score": 0.8}})

        part.clear_data([0], clear_consumption=True)

        fetched = part.get_custom_meta([0, 1])
        assert 0 not in fetched
        assert 1 in fetched
class TestDataPartitionStatusKvInterface:
    """Unit tests for DataPartitionStatus KV interface functionality.

    Covers the key -> global_index lookup ``kv_retrieve_indexes`` and the
    reverse global_index -> key lookup ``kv_retrieve_keys``. The tests seed
    ``keys_mapping`` (plus ``revert_keys_mapping`` for reverse lookups)
    directly. Unknown keys or indexes map to ``None``, and results follow the
    order of the request.

    Note: the ``test_kv_retrieve_meta_*`` names refer to ``kv_retrieve_indexes``;
    there is no ``kv_retrieve_meta`` method.
    """

    def _make_partition(self):
        """Return a fresh DataPartitionStatus shared by the KV tests."""
        from transfer_queue.controller import DataPartitionStatus

        return DataPartitionStatus(partition_id="kv_test_partition")

    def test_kv_retrieve_meta_with_existing_keys(self):
        """Test kv_retrieve_indexes returns the global_indexes of existing keys in request order."""
        partition = self._make_partition()
        partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2}
        global_indexes = partition.kv_retrieve_indexes(["key_a", "key_b", "key_c"])
        assert global_indexes == [0, 1, 2]

    def test_kv_retrieve_meta_with_nonexistent_keys(self):
        """Test kv_retrieve_indexes returns None for keys that don't exist."""
        partition = self._make_partition()
        partition.keys_mapping = {"existing_key": 5}
        global_indexes = partition.kv_retrieve_indexes(["existing_key", "nonexistent_key"])
        assert global_indexes == [5, None]

    def test_kv_retrieve_meta_empty_list(self):
        """Test kv_retrieve_indexes returns an empty list for an empty key list."""
        partition = self._make_partition()
        global_indexes = partition.kv_retrieve_indexes([])
        assert global_indexes == []

    def test_kv_retrieve_meta_partial_match(self):
        """Test kv_retrieve_indexes returns only the requested subset of keys."""
        partition = self._make_partition()
        partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30}
        global_indexes = partition.kv_retrieve_indexes(["key_1", "key_3"])
        assert global_indexes == [10, 30]

    def test_kv_retrieve_keys_with_existing_indexes(self):
        """Test kv_retrieve_keys returns correct keys for existing global_indexes."""
        partition = self._make_partition()
        partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2}
        partition.revert_keys_mapping = {0: "key_a", 1: "key_b", 2: "key_c"}
        keys = partition.kv_retrieve_keys([0, 1, 2])
        assert keys == ["key_a", "key_b", "key_c"]

    def test_kv_retrieve_keys_with_nonexistent_indexes(self):
        """Test kv_retrieve_keys returns None for global_indexes that don't exist."""
        partition = self._make_partition()
        partition.keys_mapping = {"existing_key": 5}
        partition.revert_keys_mapping = {5: "existing_key"}
        keys = partition.kv_retrieve_keys([5, 99])
        assert keys == ["existing_key", None]

    def test_kv_retrieve_keys_empty_list(self):
        """Test kv_retrieve_keys handles empty global_index list."""
        partition = self._make_partition()
        keys = partition.kv_retrieve_keys([])
        assert keys == []

    def test_kv_retrieve_keys_partial_match(self):
        """Test kv_retrieve_keys with partial global_index matches."""
        partition = self._make_partition()
        partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30}
        partition.revert_keys_mapping = {10: "key_1", 20: "key_2", 30: "key_3"}
        keys = partition.kv_retrieve_keys([10, 30])
        assert keys == ["key_1", "key_3"]
class TestFieldMetaIntegration:
"""Unit tests for DataPartitionStatus integration with FieldMeta.
Tests that _update_field_metadata correctly updates underlying FieldMeta state,
and clear_data properly handles FieldMeta when partition becomes empty or partially empty.
"""
def _make_partition(self):
from transfer_queue.controller import DataPartitionStatus
return DataPartitionStatus(partition_id="fieldmeta_integration_test")
def test_update_field_metadata_creates_fieldmeta(self):
"""Test that _update_field_metadata creates FieldMeta for new fields."""
partition = self._make_partition()
partition._update_field_metadata(
global_indexes=[0, 1, 2],
field_schema={
"input_ids": {"dtype": "torch.int32", "shape": (512,), "is_nested": False, "is_non_tensor": False},
"attention_mask": {"dtype": "torch.bool", "shape": (512,), "is_nested": False, "is_non_tensor": False},
},
)
assert "input_ids" in partition.field_metadata
assert "attention_mask" in partition.field_metadata
input_ids_meta = partition.field_metadata["input_ids"]
assert input_ids_meta.dtype == "torch.int32"
assert input_ids_meta.shape == (512,)
assert input_ids_meta.is_nested is False
assert input_ids_meta.is_non_tensor is False
assert input_ids_meta.global_indexes == {0, 1, 2}
attention_mask_meta = partition.field_metadata["attention_mask"]
assert attention_mask_meta.dtype == "torch.bool"
assert attention_mask_meta.shape == (512,)
def test_update_field_metadata_incremental_add(self):
"""Test that _update_field_metadata correctly handles incremental field additions."""
partition = self._make_partition()
partition._update_field_metadata(
global_indexes=[0, 1],
field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}},
)
field_meta = partition.field_metadata["field_a"]
assert field_meta.dtype == "torch.int32"
assert field_meta.shape == (16,)
assert field_meta.global_indexes == {0, 1}
partition._update_field_metadata(
global_indexes=[2, 3],
field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}},
)
assert field_meta.global_indexes == {0, 1, 2, 3}
assert field_meta.dtype == "torch.int32"
assert field_meta.shape == (16,)
def test_update_field_metadata_dtype_conflict_raises(self):
    """Re-registering a field with a different dtype raises ``ValueError``.

    The raised message must match the pattern "dtype mismatch".
    """
    import pytest

    part = self._make_partition()
    original_schema = {"field_x": {"dtype": "torch.int32", "shape": (16,)}}
    conflicting_schema = {"field_x": {"dtype": "torch.float64", "shape": (16,)}}

    part._update_field_metadata(global_indexes=[0], field_schema=original_schema)

    with pytest.raises(ValueError, match="dtype mismatch"):
        part._update_field_metadata(global_indexes=[1], field_schema=conflicting_schema)
def test_update_field_metadata_shape_conflict_promotes_nested(self):
    """Registering a field with a second, different shape makes it nested.

    After the conflict the FieldMeta reports no single ``shape`` (None) and
    instead records each sample's own shape in ``per_sample_shapes``.
    """
    part = self._make_partition()
    for idx, sample_shape in ((0, (256,)), (1, (128,))):
        part._update_field_metadata(
            global_indexes=[idx],
            field_schema={"field_nested": {"dtype": "torch.float32", "shape": sample_shape}},
        )

    meta = part.field_metadata["field_nested"]
    assert meta.is_nested is True
    assert meta.shape is None

    shapes = meta.per_sample_shapes
    for idx in (0, 1):
        assert idx in shapes
    assert shapes[0] == (256,)
    assert shapes[1] == (128,)
def test_update_field_metadata_with_custom_backend_meta(self):
    """Per-sample ``custom_backend_meta`` is stored, keyed by global index.

    Each index maps to a per-field dict, matching the shape of the mapping
    passed in.
    """
    part = self._make_partition()
    token_counts = {0: 100, 1: 200, 2: 300}
    part._update_field_metadata(
        global_indexes=list(token_counts),
        field_schema={"field_a": {"dtype": "torch.int32"}},
        custom_backend_meta={idx: {"field_a": {"token_count": count}} for idx, count in token_counts.items()},
    )

    stored = part.field_custom_backend_meta
    assert 0 in stored
    for idx, count in token_counts.items():
        assert stored[idx]["field_a"]["token_count"] == count
def test_update_field_metadata_empty_indexes_is_noop(self):
    """Test that _update_field_metadata with empty indexes does nothing.

    An empty ``field_schema`` can never produce a FieldMeta, whatever the
    implementation does with ``global_indexes``. That case alone therefore
    cannot detect a missing empty-index guard. The call is repeated with a
    non-empty schema: only a real "no indexes -> no-op" rule keeps that
    schema from being registered.
    """
    partition = self._make_partition()

    # Trivial case (the original check): nothing to register at all.
    partition._update_field_metadata(
        global_indexes=[],
        field_schema={},
    )
    assert partition.field_metadata == {}

    # Meaningful case: a schema is supplied, but there are no samples to attach it to.
    partition._update_field_metadata(
        global_indexes=[],
        field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}},
    )
    assert partition.field_metadata == {}
def test_clear_data_removes_samples_from_fieldmeta(self):
    """Partially clearing a partition drops the cleared samples from every FieldMeta.

    Samples 0, 2 and 4 are cleared with ``clear_consumption=False``; each
    field must then track only samples 1 and 3.
    """
    part = self._make_partition()
    fields = ("field_a", "field_b")
    part.update_production_status(
        global_indices=[0, 1, 2, 3, 4],
        field_names=list(fields),
        field_schema={
            "field_a": {"dtype": "torch.int32", "shape": (16,)},
            "field_b": {"dtype": "torch.float32", "shape": (32,)},
        },
    )

    for name in fields:
        assert part.field_metadata[name].global_indexes == {0, 1, 2, 3, 4}
    assert part.global_indexes == {0, 1, 2, 3, 4}

    part.clear_data([0, 2, 4], clear_consumption=False)

    for name in fields:
        assert part.field_metadata[name].global_indexes == {1, 3}
def test_clear_data_all_samples_clears_fieldmeta_when_empty_partition(self):
    """Clearing every produced sample leaves the partition with no FieldMeta."""
    part = self._make_partition()
    produced = (0, 1, 2)
    # Each call gets its own list so no argument object is shared between calls.
    part.update_production_status(
        global_indices=list(produced),
        field_names=["field_a", "field_b"],
        field_schema={
            "field_a": {"dtype": "torch.int32", "shape": (16,)},
            "field_b": {"dtype": "torch.float32", "shape": (32,)},
        },
    )

    assert len(part.field_metadata) == 2
    assert part.global_indexes == set(produced)

    part.clear_data(list(produced), clear_consumption=False)
    assert part.field_metadata == {}
def test_clear_data_nested_field_becomes_regular(self):
    """Clearing the only odd-shaped sample turns a nested FieldMeta back into a regular one.

    Samples 0 and 1 have shape (256,) and sample 2 has shape (128,), which
    makes the field nested. Once sample 2 is cleared the remaining samples
    agree on (256,), so the field reports a single shape and no per-sample
    shapes.
    """
    part = self._make_partition()
    batches = (
        ([0, 1], (256,)),
        ([2], (128,)),
    )
    for indices, shape in batches:
        part.update_production_status(
            global_indices=indices,
            field_names=["nested_field"],
            field_schema={"nested_field": {"dtype": "torch.float32", "shape": shape}},
        )

    before = part.field_metadata["nested_field"]
    assert before.is_nested is True
    assert before.global_indexes == {0, 1, 2}
    assert before.per_sample_shapes == {0: (256,), 1: (256,), 2: (128,)}

    part.clear_data([2], clear_consumption=False)

    # Look the entry up again: clear_data may rebuild the FieldMeta rather than mutate it.
    after = part.field_metadata["nested_field"]
    assert after.is_nested is False
    assert after.shape == (256,)
    assert after.global_indexes == {0, 1}
    assert after.per_sample_shapes == {}
def test_update_production_status_updates_field_metadata(self):
    """``update_production_status`` registers a FieldMeta with the right dtype for each schema field."""
    part = self._make_partition()
    expected_dtypes = {"input_ids": "torch.int32", "attention_mask": "torch.bool"}
    part.update_production_status(
        global_indices=[0, 1, 2],
        field_names=list(expected_dtypes),
        field_schema={
            name: {"dtype": dtype, "shape": (512,), "is_nested": False, "is_non_tensor": False}
            for name, dtype in expected_dtypes.items()
        },
    )

    for name in expected_dtypes:
        assert name in part.field_metadata
    for name, dtype in expected_dtypes.items():
        assert part.field_metadata[name].dtype == dtype
def test_fieldmeta_global_indexes_in_sync_with_partition(self):
    """After one production call, a FieldMeta tracks the same global indexes as its partition.

    Only the state right after a single ``update_production_status`` call is
    checked here.
    """
    part = self._make_partition()
    part.update_production_status(
        global_indices=[0, 1, 2, 3, 4],
        field_names=["field_a"],
        field_schema={"field_a": {"dtype": "torch.int32", "shape": (16,)}},
    )

    expected = {0, 1, 2, 3, 4}
    assert part.global_indexes == expected
    assert part.field_metadata["field_a"].global_indexes == expected
def test_fieldmeta_to_batch_schema_regular(self):
    """``FieldMeta.to_batch_schema`` on a regular field returns one flat, shape-based schema.

    ``is_non_tensor`` is omitted from the input schema but is expected to be
    reported as False. No ``per_sample_shapes`` key is present for a
    non-nested field.
    """
    part = self._make_partition()
    indexes = (0, 1, 2)
    part.update_production_status(
        global_indices=list(indexes),
        field_names=["regular_field"],
        field_schema={"regular_field": {"dtype": "torch.float32", "shape": (512,), "is_nested": False}},
    )

    expected = {
        "dtype": "torch.float32",
        "shape": (512,),
        "is_nested": False,
        "is_non_tensor": False,
    }
    schema = part.field_metadata["regular_field"].to_batch_schema(list(indexes))
    assert schema == expected
    # Implied by the equality above; kept to state the regular-vs-nested contract explicitly.
    assert "per_sample_shapes" not in schema
def test_fieldmeta_to_batch_schema_nested(self):
    """For a nested field, ``to_batch_schema`` flags it nested and lists each requested sample's shape."""
    part = self._make_partition()
    nested_schema = {
        "dtype": "torch.float32",
        "is_nested": True,
        "per_sample_shapes": {0: (3,), 1: (5,)},
    }
    part.update_production_status(
        global_indices=[0, 1],
        field_names=["nested_field"],
        field_schema={"nested_field": nested_schema},
    )

    schema = part.field_metadata["nested_field"].to_batch_schema([0, 1])
    assert schema["is_nested"] is True
    assert schema["per_sample_shapes"] == [(3,), (5,)]
def test_fieldmeta_to_batch_schema_nested_different_order(self):
    """``to_batch_schema`` returns per-sample shapes in the order the indexes are requested."""
    part = self._make_partition()
    nested_schema = {
        "dtype": "torch.float32",
        "is_nested": True,
        "per_sample_shapes": {0: (3,), 1: (5,), 2: (7,)},
    }
    part.update_production_status(
        global_indices=[0, 1, 2],
        field_names=["nested_field"],
        field_schema={"nested_field": nested_schema},
    )

    # Request the samples in a shuffled order; the output must follow it.
    requested = [2, 0, 1]
    schema = part.field_metadata["nested_field"].to_batch_schema(requested)
    assert schema["per_sample_shapes"] == [(7,), (3,), (5,)]
def test_fieldmeta_to_batch_schema_nested_missing_sample(self):
    """Test that FieldMeta.to_batch_schema returns None for missing samples.

    Only sample 0 is produced. Requesting [0, 1] must yield one entry per
    requested index: sample 0's shape, then ``None`` for sample 1, which was
    never produced for this field.
    """
    partition = self._make_partition()
    # Produce a single nested sample (index 0) with an explicit per-sample shape.
    partition.update_production_status(
        global_indices=[0],
        field_names=["nested_field"],
        field_schema={
            "nested_field": {"dtype": "torch.float32", "is_nested": True, "per_sample_shapes": {0: (3,)}}
        },
    )
    field_meta = partition.field_metadata["nested_field"]
    # Index 1 is requested even though it has no recorded shape.
    schema = field_meta.to_batch_schema([0, 1])
    assert schema["per_sample_shapes"] == [(3,), None]