import copy
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import groupby
from operator import itemgetter
from threading import Thread
from typing import TYPE_CHECKING, Any
from uuid import uuid4
import numpy as np
import ray
import torch
import zmq
from omegaconf import DictConfig
from torch import Tensor
from transfer_queue.metadata import (
BatchMeta,
)
from transfer_queue.sampler import BaseSampler, SequentialSampler
from transfer_queue.utils.enum_utils import Role
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
get_free_port,
get_node_ip_address,
)
if TYPE_CHECKING:
from transfer_queue.metrics import TQMetricsExporter
# Module-level logger, obtained from the project's logging helper and keyed by this module's name.
logger = get_logger(__name__)
# The two settings below are read from the environment once, at import time, and coerced with int()
# (a non-integer value therefore raises ValueError on import). Defaults are 1 and 5 respectively.
# NOTE(review): their consumers are outside this view; presumably both are durations in seconds used
# by the controller's get_metadata wait loop — confirm units and that timeout < interval is intended.
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1))
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5))
class PartitionIndexManager:
    """
    Manages the mapping relationship between partitions and global indexes,
    responsible for index allocation and reuse.

    Attributes:
        partition_to_indexes: partition_id -> set of global indexes currently owned by that partition.
        reusable_indexes: Released indexes, handed out again (front of the list first) before any
            new index is minted.
        global_index_counter: The next never-used global index.
        allocated_indexes: Every index currently owned by some partition (new or reused).
    """

    def __init__(self):
        self.partition_to_indexes = defaultdict(set)
        self.reusable_indexes = []
        self.global_index_counter = 0
        self.allocated_indexes = set()

    def allocate_indexes(self, partition_id, count=1) -> list[int]:
        """
        Allocate global_indexes for the specified partition.
        Prioritizes obtaining from reusable pool, allocates new indexes when insufficient.

        Args:
            partition_id: Partition ID
            count: Number of indexes needed

        Returns:
            list: List of allocated global_indexes (reused ones first, then newly minted ones)

        Raises:
            ValueError: If count is not a positive number.
        """
        if count <= 0:
            raise ValueError(f"Number of indexes needed must be larger than 0, but got {count}")

        # Drain the reusable pool first.
        num_reuse = min(count, len(self.reusable_indexes))
        indexes = self.reusable_indexes[:num_reuse]
        del self.reusable_indexes[:num_reuse]

        # Mint fresh indexes for whatever the pool could not cover.
        needed = count - num_reuse
        if needed > 0:
            start_index = self.global_index_counter
            self.global_index_counter = start_index + needed
            indexes.extend(range(start_index, self.global_index_counter))

        # Track reused indexes as well as new ones: releasing removes them from
        # allocated_indexes, so re-allocating must put them back.
        self.allocated_indexes.update(indexes)
        self.partition_to_indexes[partition_id].update(indexes)
        return indexes

    def release_partition(self, partition_id) -> list[int]:
        """
        Release all global_indexes of the specified partition, adding them to reusable pool.

        Args:
            partition_id: Partition ID

        Returns:
            list: List of released global_indexes (empty if the partition is unknown)
        """
        # Membership test instead of indexing, so the defaultdict does not create an empty entry.
        if partition_id not in self.partition_to_indexes:
            return []
        released = list(self.partition_to_indexes.pop(partition_id))
        self.reusable_indexes.extend(released)
        self.allocated_indexes.difference_update(released)
        return released

    def release_indexes(self, partition_id: str, indexes_to_release: list[int]) -> list[int]:
        """
        Release specific global_indexes for a partition, adding them to reusable pool.

        Args:
            partition_id: Partition ID
            indexes_to_release: List of specific indexes to release

        Returns:
            list: The released indexes, de-duplicated in first-seen order
            (empty if the partition is unknown).

        Raises:
            ValueError: If any index does not belong to the partition. Nothing is released then.
        """
        if partition_id not in self.partition_to_indexes:
            return []
        partition_indexes = self.partition_to_indexes[partition_id]

        # De-duplicate: a repeated index must enter the reusable pool only once,
        # otherwise it could later be handed out to two owners.
        released = list(dict.fromkeys(indexes_to_release))
        if not partition_indexes.issuperset(released):
            raise ValueError("Some indexes to release do not belong to the specified partition.")

        partition_indexes.difference_update(released)
        self.reusable_indexes.extend(released)
        self.allocated_indexes.difference_update(released)

        # Drop the partition entry once it no longer owns anything.
        if not partition_indexes:
            self.partition_to_indexes.pop(partition_id, None)
        return released

    def get_indexes_for_partition(self, partition_id) -> list[int]:
        """
        Get all global_indexes for the specified partition.

        Args:
            partition_id: Partition ID

        Returns:
            list: List of global_indexes for this partition (a fresh list; empty if unknown)
        """
        # .get() avoids creating a defaultdict entry for unknown partitions.
        return list(self.partition_to_indexes.get(partition_id, ()))
@dataclass
class FieldMeta:
"""
Single source of truth for one field's metadata in a partition.
Field-level attributes (dtype/shape/is_nested/is_non_tensor) are shared across all samples, O(1) storage.
Sample-level attributes (per_sample_shapes) are only needed for nested tensors,
indexed by global_idx, O(B_nested) storage.
"""
global_indexes: set[int] = field(default_factory=set)
dtype: Any | None = None
shape: tuple | None = None
is_nested: bool | None = None
is_non_tensor: bool | None = None
per_sample_shapes: dict[int, tuple] = field(default_factory=dict)
def update(self, incoming: dict[str, Any], incoming_global_indexes: list[int]) -> None:
"""Update this field's metadata from an incoming schema dict.
Encapsulates dtype consistency check, shape conflict detection,
and automatic is_nested inference.
Args:
incoming: Schema dict with optional keys:
global_indexes, dtype, shape, is_nested, is_non_tensor, per_sample_shapes
incoming_global_indexes: global indexes of the input meta
Raises:
ValueError: If incoming dtype conflicts with existing dtype.
"""
new_dtype = incoming.get("dtype")
if new_dtype is not None:
if self.dtype is None:
self.dtype = new_dtype
elif self.dtype != new_dtype:
raise ValueError(
f"dtype mismatch: existing={self.dtype}, incoming={new_dtype}. "
f"All batches for the same field must have the same dtype."
)
new_is_nested = incoming.get("is_nested")
new_is_non_tensor = incoming.get("is_non_tensor")
if new_is_nested:
new_per_sample_shapes = incoming.get("per_sample_shapes", None)
if new_per_sample_shapes is None:
raise ValueError("Receiving a nested field without 'per_sample_shapes'!")
if self.is_nested is not None and not self.is_nested:
assert self.shape is not None
for gi in self.global_indexes:
self.per_sample_shapes[gi] = self.shape
self.is_nested = True
self.shape = None
self.per_sample_shapes.update(new_per_sample_shapes)
else:
if not new_is_non_tensor:
new_shape = incoming.get("shape", None)
if new_shape is None:
raise ValueError("Receiving a regular tensor without 'shape'!")
if self.is_nested:
for gi in incoming_global_indexes:
self.per_sample_shapes[gi] = new_shape
else:
if self.is_non_tensor is not None and not self.is_non_tensor:
assert self.shape is not None
if self.shape != new_shape:
for gi in self.global_indexes:
self.per_sample_shapes[gi] = self.shape
for gi in incoming_global_indexes:
self.per_sample_shapes[gi] = new_shape
self.shape = None
self.is_nested = True
self.global_indexes.update(incoming_global_indexes)
def remove_samples(self, indexes: list[int]):
"""Remove sample-level data for the given indexes."""
for idx in indexes:
self.per_sample_shapes.pop(idx, None)
self.global_indexes.discard(idx)
if len(self.global_indexes) == 0:
self.is_nested = None
self.is_non_tensor = None
self.shape = None
self.dtype = None
self.per_sample_shapes.clear()
else:
if self.is_nested:
remaining_shapes = set(
tuple(shape) if isinstance(shape, list) else shape for shape in self.per_sample_shapes.values()
)
if len(remaining_shapes) == 1:
self.is_nested = False
self.shape = next(iter(remaining_shapes))
self.per_sample_shapes.clear()
def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]:
"""Export as a BatchMeta.field_schema-compatible dict for generate_batch_meta."""
schema = {
"dtype": self.dtype,
"shape": self.shape,
"is_nested": self.is_nested,
"is_non_tensor": self.is_non_tensor,
}
if self.is_nested and self.per_sample_shapes:
schema["per_sample_shapes"] = [self.per_sample_shapes.get(gi) for gi in batch_global_indexes]
return schema
@dataclass
class DataPartitionStatus:
"""
Robust status information for a data partition with dynamic expansion support.
This class tracks the production and consumption status of data within a specific
partition (e.g., "train@global_batch_0", "inference@kv_cache_1") with full support
for dynamic row and column expansion.
"""
partition_id: str
created_at: float = field(default_factory=time.time)
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
consumption_status: dict[str, Tensor] = field(default_factory=dict)
global_indexes: set[int] = field(
default_factory=set
)
pre_allocated_global_indexes: set[int] = field(
default_factory=set
)
field_name_mapping: dict[str, int] = field(default_factory=dict)
field_metadata: dict[str, FieldMeta] = field(default_factory=dict)
field_custom_backend_meta: dict[int, dict[str, Any]] = field(
default_factory=dict
)
custom_meta: dict[int, dict[str, Any]] = field(default_factory=dict)
keys_mapping: dict[str, int] = field(default_factory=dict)
revert_keys_mapping: dict[int, str] = field(default_factory=dict)
@property
def total_samples_num(self) -> int:
"""Current number of samples in the partition."""
return len(self.global_indexes)
@property
def total_fields_num(self) -> int:
"""Current number of fields (columns) in the partition."""
return len(self.field_name_mapping)
@property
def allocated_fields_num(self) -> int:
"""Current number of allocated columns in the tensor."""
return self.production_status.shape[1]
@property
def allocated_samples_num(self) -> int:
"""Current number of allocated rows in the tensor."""
return self.production_status.shape[0]
def register_pre_allocated_indexes(self, allocated_indexes: list[int]):
"""
Register pre-allocated sample indexes to this partition.
These indexes are reserved before actual data production, allowing consumers
to see the expected total sample count via get_consumption_status even when
producers haven't generated all samples yet.
Args:
allocated_indexes: List of global indexes to pre-allocate
"""
if len(allocated_indexes) < 1:
logger.info("Trying to pre-allocate global_indexes with empty list!")
return
self.pre_allocated_global_indexes.update(allocated_indexes)
max_sample_idx = max(allocated_indexes)
required_samples = max_sample_idx + 1
self.ensure_samples_capacity(required_samples)
logger.debug(f"Pre-allocated indexes in {self.partition_id}: {allocated_indexes}")
def activate_pre_allocated_indexes(self, sample_num: int) -> list[int]:
"""
Activate and retrieve pre-allocated indexes for use in data insertion.
This method consumes pre-allocated indexes and marks them as active in global_indexes.
If pre-allocated indexes are insufficient, returns all available ones.
Args:
sample_num: Number of indexes needed
Returns:
List of retrieved global indexes
"""
available_indexes = len(self.pre_allocated_global_indexes)
if available_indexes < sample_num:
global_index_to_allocate = list(self.pre_allocated_global_indexes)
logger.debug(
f"Not enough pre-allocated indexes in partition {self.partition_id}. "
f"Returning {available_indexes} of {sample_num} requested."
)
else:
global_index_to_allocate = list(sorted(self.pre_allocated_global_indexes))[:sample_num]
self.global_indexes.update(global_index_to_allocate)
self.pre_allocated_global_indexes.difference_update(set(global_index_to_allocate))
return global_index_to_allocate
def ensure_samples_capacity(self, required_samples: int) -> None:
"""
Ensure the production status tensor has enough rows for the required samples.
Args:
required_samples: Minimum number of samples needed
"""
current_sample_space = self.allocated_samples_num
if required_samples > current_sample_space:
expansion_needed = required_samples - current_sample_space
new_samples = current_sample_space + expansion_needed
new_fields = self.production_status.shape[1]
expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
expanded_tensor[:current_sample_space, :] = self.production_status
self.production_status = expanded_tensor
for task_name, consumption_tensor in self.consumption_status.items():
expanded_consumption = torch.zeros(new_samples, dtype=torch.int8)
expanded_consumption[:current_sample_space] = consumption_tensor
self.consumption_status[task_name] = expanded_consumption
logger.debug(f"Expanded partition {self.partition_id} from {current_sample_space} to {new_samples} samples")
def ensure_fields_capacity(self, required_fields: int) -> None:
"""
Ensure the production status tensor has enough columns for the required fields.
Args:
required_fields: Minimum number of fields needed
"""
current_fields = self.production_status.shape[1]
if required_fields > current_fields:
expansion_needed = required_fields - current_fields
new_fields = current_fields + expansion_needed
new_samples = self.production_status.shape[0]
expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
expanded_tensor[:, :current_fields] = self.production_status
self.production_status = expanded_tensor
logger.debug(f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields")
def update_production_status(
self,
global_indices: list[int],
field_names: list[str],
field_schema: dict[str, dict[str, Any]],
custom_backend_meta: dict[int, dict[str, Any]] | None = None,
) -> bool:
"""
Update production status for specific samples and fields.
Handles dynamic expansion of both samples and fields.
Note: field_names is derived from field_schema.keys() internally.
The parameter is kept for backward compatibility but ignored;
callers should ensure field_schema contains all intended fields.
Args:
global_indices: List of sample indices to update
field_names: List of field names (ignored; derived from field_schema.keys())
field_schema: Columnar field schema {field_name: {dtype, shape, is_nested, ...}}
custom_backend_meta: Optional per-sample per-field
custom metadata provided by storage backend
Returns:
True if update was successful, False on error
"""
try:
field_names = list(field_schema.keys())
max_sample_idx = max(global_indices) if global_indices else -1
required_samples = max_sample_idx + 1
self.ensure_samples_capacity(required_samples)
new_fields = [f for f in field_names if f not in self.field_name_mapping]
if new_fields:
for f in new_fields:
self.field_name_mapping[f] = len(self.field_name_mapping)
required_fields = len(self.field_name_mapping)
self.ensure_fields_capacity(required_fields)
if self.production_status is not None and global_indices and field_names:
field_indices = [self.field_name_mapping.get(f) for f in field_names]
self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1
self._update_field_metadata(global_indices, field_schema, custom_backend_meta)
self.global_indexes.update(global_indices)
return True
except Exception as e:
logger.error(f"Error updating production status for partition {self.partition_id}: {e}")
return False
def _update_field_metadata(
self,
global_indexes: list[int],
field_schema: dict[str, dict[str, Any]],
custom_backend_meta: dict[int, dict[str, Any]] | None = None,
):
"""Update field metadata from columnar field_schema."""
if not global_indexes:
return
for field_name, meta in field_schema.items():
if field_name not in self.field_metadata:
self.field_metadata[field_name] = FieldMeta(
global_indexes=set(global_indexes),
dtype=meta.get("dtype"),
shape=meta.get("shape"),
is_nested=meta.get("is_nested", False),
is_non_tensor=meta.get("is_non_tensor", False),
per_sample_shapes=meta.get("per_sample_shapes", {}),
)
else:
self.field_metadata[field_name].update(meta, global_indexes)
if custom_backend_meta:
for global_idx, per_field_meta in custom_backend_meta.items():
if global_idx not in self.field_custom_backend_meta:
self.field_custom_backend_meta[global_idx] = {}
self.field_custom_backend_meta[global_idx].update(per_field_meta)
def mark_consumed(self, task_name: str, global_indices: list[int]):
"""
Mark specific samples as consumed by a task.
Args:
task_name: Name of the consumer task
global_indices: List of sample indices to mark as consumed
"""
try:
_, consumption_status = self.get_consumption_status(task_name, mask=False)
if consumption_status.numel() > 0 and global_indices:
consumption_status[global_indices] = 1
except Exception as e:
logger.error(
f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}. "
f"Target global_indices {global_indices}, but current consumption_status has "
f"shape {consumption_status.shape}"
)
def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]:
"""
Get or create consumption status for a specific task.
Handles dynamic expansion when new samples are added.
Args:
task_name: Name of the consumer task
mask: Whether to return only the status for current partition samples
Returns:
Tuple of:
- Partition global index tensor
- Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
"""
if task_name not in self.consumption_status:
if self.production_status is not None:
self.consumption_status[task_name] = torch.zeros(self.allocated_samples_num, dtype=torch.int8)
else:
self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8)
partition_global_index = torch.tensor(
sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long
)
if mask:
if partition_global_index.numel() == 0:
empty_status = self.consumption_status[task_name].new_zeros(0)
return partition_global_index, empty_status
self.ensure_samples_capacity(max(partition_global_index) + 1)
consumption_status = self.consumption_status[task_name][partition_global_index]
else:
consumption_status = self.consumption_status[task_name]
return partition_global_index, consumption_status
def reset_consumption(self, task_name: str | None = None):
"""
Reset consumption status for a specific task or all tasks.
This allows the same data to be re-consumed without clearing the actual data.
Useful for debugging scenarios where the same rollout data needs to be
trained multiple times.
Args:
task_name: Name of the task to reset consumption for.
If None, resets consumption status for all tasks.
"""
if task_name is not None:
if task_name in self.consumption_status:
self.consumption_status[task_name].zero_()
logger.debug(f"Reset consumption status for task '{task_name}' in partition {self.partition_id}")
else:
for name, status_tensor in self.consumption_status.items():
status_tensor.zero_()
logger.debug(f"Reset consumption status for all tasks in partition {self.partition_id}")
def get_production_status_for_fields(
self, field_names: list[str], mask: bool = False
) -> tuple[Tensor | None, Tensor | None]:
"""
Check if all samples for specified fields are fully produced and ready.
Args:
field_names: List of field names to check production status for
mask: Whether to return only the status for current partition samples
Returns:
Tuple of:
- Partition global index tensor
- Production status tensor for the specified task. 1 for ready, 0 for not ready.
"""
if field_names is None or len(field_names) == 0:
return None, None
for field_name in field_names:
if field_name not in self.field_name_mapping:
return None, None
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
field_indices = [self.field_name_mapping[field] for field in field_names]
if field_indices:
col_mask[field_indices] = True
production_status = self.production_status[:, col_mask]
partition_global_index = torch.tensor(
sorted(self.global_indexes | self.pre_allocated_global_indexes), dtype=torch.long
)
if mask:
production_status = production_status[partition_global_index]
return partition_global_index, production_status
def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
"""
Scan data status to find samples ready for consumption.
This replaces the original _scan_data_status functionality.
Args:
field_names: List of required field names
task_name: Name of the consumer task
Returns:
List of sample indices that are ready for consumption
"""
for field_name in field_names:
if field_name not in self.field_name_mapping:
return []
row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool)
_, consumption_status = self.get_consumption_status(task_name, mask=False)
if consumption_status is not None:
unconsumed_mask = consumption_status == 0
row_mask &= unconsumed_mask
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
field_indices = [self.field_name_mapping[field] for field in field_names]
if field_indices:
col_mask[field_indices] = True
relevant_status = self.production_status[row_mask][:, col_mask]
all_fields_ready = torch.all(relevant_status, dim=1)
ready_indices_in_filtered = torch.nonzero(all_fields_ready, as_tuple=False).flatten()
all_indices = torch.where(row_mask)[0]
ready_sample_indices = all_indices[ready_indices_in_filtered].tolist()
return ready_sample_indices
def get_field_schema(
self, field_names: list[str], batch_global_indexes: list[int] | None = None
) -> dict[str, dict[str, Any]]:
"""Return field_schema from the FieldMeta store."""
gi = batch_global_indexes or []
return {f: self.field_metadata[f].to_batch_schema(gi) for f in field_names if f in self.field_metadata}
def get_field_custom_backend_meta(
self, global_indices: list[int], field_names: list[str]
) -> dict[int, dict[str, Any]]:
"""
Get custom_backend_meta for multiple samples and fields.
This method retrieves backend-specific metadata stored at per-sample per-field level.
The returned dictionary maps global_index to a dictionary of field_name to metadata.
Args:
global_indices: List of global sample indices to retrieve metadata for
field_names: List of field names to filter by. Only metadata for these
fields will be included in the result.
Returns:
Dictionary mapping global_index to field-name-to-metadata mapping.
Only includes indices that have custom_backend_meta set.
Example:
>>> partition.get_field_custom_backend_meta([0, 1], ["field_a", "field_b"])
{0: {'field_a': {'meta1': 'xxx'}, 'field_b': {'meta1': 'xxx'}}, 1: {...}}
"""
return {
idx: {f: v for f, v in self.field_custom_backend_meta[idx].items() if f in field_names}
for idx in global_indices
if idx in self.field_custom_backend_meta
}
def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]:
"""
Get custom_meta for multiple samples.
This method retrieves user-defined per-sample metadata.
Args:
global_indices: List of global sample indices to retrieve metadata for
Returns:
Dictionary mapping global_index to custom metadata dict.
Only includes indices that have custom_meta set.
Example:
>>> partition.get_custom_meta([0, 2])
{0: {'score': 0.9}, 2: {'label': 'positive'}}
"""
return {idx: self.custom_meta[idx] for idx in global_indices if idx in self.custom_meta}
def set_custom_meta(self, custom_meta: dict[int, dict]) -> None:
"""
Set custom_meta for multiple samples.
This method sets or updates user-defined per-sample metadata.
Args:
custom_meta: Dictionary mapping global_index to custom metadata dict.
Existing entries will be overwritten.
"""
self.custom_meta.update(custom_meta)
def get_statistics(self) -> dict[str, Any]:
"""Get detailed statistics for this partition."""
stats = {
"partition_id": self.partition_id,
"created_at": self.created_at,
"total_samples_num": self.total_samples_num,
"total_fields_num": self.total_fields_num,
"allocated_samples_num": self.allocated_samples_num,
"allocated_fields_num": self.allocated_fields_num,
"registered_tasks": list(self.consumption_status.keys()),
}
if self.production_status is not None:
produced_samples = torch.any(self.production_status == 1, dim=1).sum().item()
stats["produced_samples"] = produced_samples
stats["production_progress"] = (
produced_samples / self.total_samples_num if self.total_samples_num > 0 else 0
)
field_stats = {}
for field_name, field_idx in self.field_name_mapping.items():
field_produced = (self.production_status[:, field_idx] == 1).sum().item()
field_stats[field_name] = {
"produced_samples": field_produced,
"production_progress": (
field_produced / self.total_samples_num if self.total_samples_num > 0 else 0
),
}
stats["field_statistics"] = field_stats
consumption_stats = {}
for task_name, consumption_tensor in self.consumption_status.items():
consumed_samples = (consumption_tensor == 1).sum().item()
consumption_stats[task_name] = {
"consumed_samples": consumed_samples,
"consumption_progress": (
consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0
),
}
stats["consumption_statistics"] = consumption_stats
return stats
def to_snapshot(self):
"""
Get a snapshot of partition status information.
Returns:
DataPartitionStatus object
"""
cls = self.__class__
snapshot = cls.__new__(cls)
for name, value in self.__dict__.items():
if isinstance(value, Tensor):
new_val = value.clone().detach()
else:
new_val = copy.deepcopy(value)
setattr(snapshot, name, new_val)
return snapshot
def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = True):
"""Clear all production and optionally consumption data for given global_indexes."""
try:
if self.production_status is not None:
self.production_status[indexes_to_release, :] = 0
if clear_consumption:
for consumption_tensor in self.consumption_status.values():
consumption_tensor[indexes_to_release] = 0
self.global_indexes.difference_update(indexes_to_release)
empty_fields = []
for field_name, field_meta in self.field_metadata.items():
field_meta.remove_samples(indexes_to_release)
if len(field_meta.global_indexes) == 0:
empty_fields.append(field_name)
if len(self.global_indexes) == 0:
self.field_metadata.clear()
else:
for field_name in empty_fields:
self.field_metadata.pop(field_name)
for idx in indexes_to_release:
self.field_custom_backend_meta.pop(idx, None)
self.custom_meta.pop(idx, None)
if idx in self.revert_keys_mapping:
self.keys_mapping.pop(self.revert_keys_mapping[idx], None)
self.revert_keys_mapping.pop(idx, None)
except Exception as e:
logger.error(
f"Error clearing data for partition {self.partition_id}: {e}. "
f"Attempted to clear global_indexes: {indexes_to_release}"
)
def kv_retrieve_indexes(self, keys: list[str]) -> list[int | None]:
"""Translate the user-specified keys to global_indexes"""
global_indexes = [self.keys_mapping.get(k, None) for k in keys]
return global_indexes
def kv_retrieve_keys(self, global_indexes: list[int]) -> list[str | None]:
"""Translate the global_indexes to keys"""
keys = [self.revert_keys_mapping.get(idx, None) for idx in global_indexes]
return keys
@ray.remote(num_cpus=1)
class TransferQueueController:
"""
TransferQueue Controller with partition-based data management.
This refactored controller manages data through dynamic partitions instead of
fixed global batches. Each partition represents a logical data container
(e.g., "train@global_batch_0", "inference@kv_cache_1") that can be created
on-demand and managed independently.
Key improvements:
- Dynamic partition creation on-demand
- No dependency on training-specific parameters (global_batch_size, etc.)
- Support for diverse use cases (KV cache migration, model resharding, etc.)
- Flexible data organization through partition-based addressing
"""
def __init__(
    self,
    sampler: BaseSampler | type[BaseSampler] = SequentialSampler,
    polling_mode: bool = False,
) -> None:
    """Set up controller state, sockets and background workers.

    Args:
        sampler: Either a ready-to-use BaseSampler instance (used as-is) or a BaseSampler
            subclass (instantiated with no arguments). Defaults to SequentialSampler.
            - Example: sampler=GRPOGroupNSampler() (instance)
            - Example: sampler=SequentialSampler (class)
        polling_mode: Whether to use polling mode for TransferQueue controller.
            - If False, the controller will raise an error when no enough data is available.
            - If True, the controller will return an empty BatchMeta when no enough data is available.
              The user side is responsible for handling this empty case (retrying later).

    Raises:
        TypeError: If sampler is neither a BaseSampler instance nor a BaseSampler subclass.
    """
    # Resolve the sampler argument up front; anything else is rejected before any state is created.
    given_instance = isinstance(sampler, BaseSampler)
    given_class = not given_instance and isinstance(sampler, type) and issubclass(sampler, BaseSampler)
    if not (given_instance or given_class):
        raise TypeError(
            f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
        )
    self.sampler = sampler if given_instance else sampler()

    # Identity and configuration.
    self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}"
    self.polling_mode = polling_mode
    self.tq_config = None

    self._init_zmq_socket()

    # Partition bookkeeping.
    self.partitions: dict[str, DataPartitionStatus] = {}
    self.index_manager = PartitionIndexManager()
    self._connected_storage_managers: set[str] = set()

    # Metrics exporter is attached later, if at all.
    self._metrics: TQMetricsExporter | None = None
    self._metrics_endpoint: str = ""

    # Started last, after the state above exists.
    self._start_process_handshake()
    self._start_process_request()

    logger.info(f"TransferQueue Controller {self.controller_id} initialized")
def create_partition(self, partition_id: str) -> bool:
    """
    Create a new data partition with pre-allocated sample indexes.

    Partitions dynamically expand as needed. Additionally, TQ_PRE_ALLOC_SAMPLE_NUM
    indexes are pre-allocated to allow consumers to determine consumption status
    before all samples are produced.

    The partition is only published in self.partitions after index allocation and
    registration succeeded, so a failure never leaves a half-created partition behind.

    Args:
        partition_id: Unique identifier for the partition (e.g., "train@global_batch_0")

    Returns:
        True if partition was created successfully, False if it already exists

    Raises:
        ValueError: If TQ_PRE_ALLOC_SAMPLE_NUM is not an integer, or is not positive
            (the latter is raised by the index manager's allocate_indexes).
    """
    # Read at call time, so changes to the environment after import are honoured here.
    pre_alloc_num = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

    if partition_id in self.partitions:
        logger.warning(f"Partition {partition_id} already exists")
        return False

    partition = DataPartitionStatus(partition_id=partition_id)
    global_indexes = self.index_manager.allocate_indexes(partition_id, count=pre_alloc_num)
    try:
        partition.register_pre_allocated_indexes(global_indexes)
    except Exception:
        # Hand the indexes back so they are not owned by a partition that never came to exist.
        self.index_manager.release_partition(partition_id)
        raise

    self.partitions[partition_id] = partition
    logger.info(f"Created partition {partition_id} with {pre_alloc_num} pre-allocated indexes")
    return True
def _get_partition(self, partition_id: str) -> DataPartitionStatus | None:
    """
    Look up the live (not copied) status object of a partition.

    Args:
        partition_id: ID of the partition to retrieve

    Returns:
        DataPartitionStatus object if partition exists, None otherwise
    """
    if partition_id in self.partitions:
        return self.partitions[partition_id]
    return None
def get_partition_snapshot(self, partition_id: str) -> DataPartitionStatus | None:
    """
    Return an independent copy of a partition's status, produced by its to_snapshot().

    Args:
        partition_id: ID of the partition to retrieve

    Returns:
        DataPartitionStatus object if partition exists, None otherwise
    """
    live_partition = self._get_partition(partition_id)
    return None if live_partition is None else live_partition.to_snapshot()
def list_partitions(self) -> list[str]:
    """
    Return the IDs of all partitions currently registered with the controller.

    Returns:
        List of partition IDs (a fresh list, in the dict's insertion order)
    """
    return [partition_id for partition_id in self.partitions]
def get_partition_index_range(self, partition_id: str) -> list[int]:
    """
    Return the global indexes the index manager has recorded for a partition.

    Args:
        partition_id: Partition identifier

    Returns:
        List of indexes allocated to the partition
    """
    partition_indexes = self.index_manager.get_indexes_for_partition(partition_id)
    return partition_indexes
def update_production_status(
    self,
    partition_id: str,
    global_indexes: list[int],
    field_schema: dict[str, dict[str, Any]],
    custom_backend_meta: dict[int, dict[str, Any]] | None = None,
) -> bool:
    """
    Record that the given samples/fields of a partition have been produced.

    The work is delegated to the partition's own update_production_status; this wrapper
    only resolves the partition and logs the outcome.

    Args:
        partition_id: ID of the partition
        global_indexes: List of sample indices to update
        field_schema: Columnar field schema {field_name: {dtype, shape, is_nested, ...}}
        custom_backend_meta: Optional custom backend metadata

    Returns:
        True if update was successful, False otherwise (including when the partition is unknown)
    """
    field_names = [*field_schema.keys()]

    partition = self._get_partition(partition_id)
    if partition:
        updated = partition.update_production_status(global_indexes, field_names, field_schema, custom_backend_meta)
        if updated:
            logger.debug(
                f"[{self.controller_id}]: Updated production status for partition {partition_id}: "
                f"samples={global_indexes}, fields={field_names}"
            )
        return updated

    logger.error(f"Partition {partition_id} not found")
    return False
def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Tensor | None, Tensor | None]:
    """
    Fetch (creating it on first use) a task's consumption status for one partition.

    Delegates to the partition with mask=True, so only the partition's own samples are returned.

    Args:
        partition_id: ID of the partition
        task_name: Name of the consumer task

    Returns:
        Tuple of:
        - Partition global index tensor
        - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed.
        (None, None) if the partition does not exist.
    """
    partition = self._get_partition(partition_id)
    if partition:
        return partition.get_consumption_status(task_name, mask=True)
    return None, None
def get_production_status(self, partition_id: str, data_fields: list[str]) -> tuple[Tensor | None, Tensor | None]:
    """
    Fetch the production status of the given fields within a partition.

    The request is forwarded to the partition with ``mask=True``.

    Args:
        partition_id: ID of the partition to inspect
        data_fields: Field names whose production status is requested

    Returns:
        ``(None, None)`` when the partition is unknown; otherwise the pair
        returned by the partition (global index tensor, production status
        tensor where 1 means ready and 0 means not ready).
    """
    partition = self._get_partition(partition_id)
    if partition:
        return partition.get_production_status_for_fields(data_fields, mask=True)
    return None, None
def set_custom_meta(self, partition_custom_meta: dict[str, dict[int, dict]]) -> None:
    """
    Set custom_meta for samples in partitions.

    Each entry of ``partition_custom_meta`` is handed to the matching
    partition's ``set_custom_meta``. Entries that name a partition unknown to
    this controller are skipped with a warning; the remaining entries are
    still applied.

    Args:
        partition_custom_meta: Dictionary mapping partition_id to custom metadata dict.
            Format: {partition_id: {global_index: {metadata_key: metadata_value}}}
            - partition_id: ID of the partition
            - global_index: Global index of the sample
            - metadata_key/value: User-defined metadata key-value pairs

    Example:
        >>> # Set custom metadata for samples in different partitions
        >>> controller.set_custom_meta({
        ...     "train_0": {
        ...         0: {"score": 0.9, "label": "positive"},
        ...         1: {"score": 0.8, "label": "negative"}
        ...     },
        ...     "train_1": {
        ...         10: {"score": 0.95, "label": "positive"}
        ...     }
        ... })
    """
    for partition_id, custom_meta in partition_custom_meta.items():
        partition = self._get_partition(partition_id)
        if partition:
            partition.set_custom_meta(custom_meta)
        else:
            # The partition id is quoted on both sides so ids containing
            # spaces or punctuation stay readable in the log.
            logger.warning(
                f"set_custom_meta: partition '{partition_id}' not found; "
                f"custom_metadata for this partition will be ignored"
            )
def get_metadata(
    self,
    data_fields: list[str],
    partition_id: str,
    mode: str = "fetch",
    task_name: str | None = None,
    batch_size: int | None = None,
    sampling_config: dict[str, Any] | None = None,
    *args,
    **kwargs,
) -> BatchMeta:
    """
    Retrieve metadata with support for three modes.

    Args:
        data_fields: List of field names to include in metadata
        partition_id: Partition id for which to retrieve metadata
        mode: Operation mode - 'insert', 'fetch', or 'force_fetch'
            - mode="insert": Allocate indexes for new samples (the partition is
              created on demand) and return their metadata
            - mode="fetch": Wait until enough samples are ready, pick a batch with
              the configured sampler and mark the sampler-reported samples as
              consumed for ``task_name``
            - mode="force_fetch": Return metadata for every index currently
              allocated to the partition, without sampling and without touching
              consumption state
        task_name: Name of the consumer task (required in fetch mode)
        batch_size: Number of samples to retrieve (required in insert and fetch modes)
        sampling_config: Optional keyword arguments forwarded to the sampler
        *args: Additional positional arguments (unused)
        **kwargs: Additional keyword arguments forwarded to the sampler

    Returns:
        BatchMeta object containing the requested metadata. In polling mode an
        empty BatchMeta is returned when not enough data is available.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes, or
            ``batch_size`` is missing in insert/fetch mode
        RuntimeError: If ``data_fields`` is None in insert mode, or the sampler
            returns no samples outside polling mode
        TimeoutError: If waiting for sufficient data times out in fetch mode
    """
    # Reject unknown modes up front; previously they fell through to an
    # unbound local further down.
    if mode not in ("insert", "fetch", "force_fetch"):
        raise ValueError(f"Invalid mode: {mode}")

    if mode == "insert":
        # Validate before touching any state so an invalid call does not
        # leave an empty partition behind.
        if data_fields is None:
            raise RuntimeError("Must provide data_fields for inserting new data")
        if batch_size is None:
            raise ValueError("must provide batch_size for inserting new data")

        if partition_id not in self.partitions:
            self.create_partition(partition_id)
        partition = self._get_partition(partition_id)
        assert partition is not None

        # Prefer pre-allocated indexes; top up from the index manager if short.
        batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size)
        if len(batch_global_indexes) < batch_size:
            new_global_indexes = self.index_manager.allocate_indexes(
                partition_id, count=(batch_size - len(batch_global_indexes))
            )
            batch_global_indexes.extend(new_global_indexes)
        partition.global_indexes.update(batch_global_indexes)
        return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)

    if mode == "fetch":
        assert task_name is not None
        if batch_size is None:
            raise ValueError("must provide batch_size in fetch mode")

        start_time = time.time()
        while True:
            ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name)
            if len(ready_for_consume_indexes) >= batch_size:
                break

            if self.polling_mode:
                if self.sampler.has_cached_result(partition_id, task_name, sampling_config):
                    break
                logger.debug(
                    f"[{self.controller_id}]: Not enough data for task {task_name} in "
                    f"partition {partition_id}. Required: {batch_size}, "
                    f"Available: {len(ready_for_consume_indexes)}."
                    f" Returning None due to polling mode."
                )
                return BatchMeta.empty()

            # The deadline is checked right after a fresh scan, so data that
            # arrived during the previous sleep is always taken into account
            # before giving up (matters when the check interval exceeds the
            # timeout).
            if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
                raise TimeoutError(
                    f"Timeout while waiting for sufficient data for task {task_name}. "
                    f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}"
                )

            logger.warning(
                f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
                f"samples with fields {data_fields} in partition {partition_id}, but only have "
                f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
                f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
            )
            time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)

        batch_global_indexes, consumed_indexes = self.sampler(
            ready_for_consume_indexes,
            batch_size,
            partition=self._get_partition(partition_id),
            **(sampling_config or {}),
            **kwargs,
        )

        if len(batch_global_indexes) == 0:
            if self.polling_mode:
                return BatchMeta.empty()
            raise RuntimeError(
                f"Sampler returned no samples. Please check the sampler logic. "
                f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, "
                f"after sampling: {len(batch_global_indexes)}"
            )

        # Only the indexes the sampler reports as consumed are marked.
        if consumed_indexes:
            partition = self.partitions[partition_id]
            partition.mark_consumed(task_name, consumed_indexes)
    else:
        # force_fetch: everything allocated to the partition, no state change.
        batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)

    return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
def scan_data_status(
    self,
    partition_id: str,
    data_fields: list[str],
    task_name: str,
) -> list[int]:
    """
    List the samples of a partition that a task can consume right now.

    The scan itself is carried out by the partition object.

    Args:
        partition_id: ID of the partition to scan
        data_fields: Field names the samples must provide
        task_name: Name of the consumer task

    Returns:
        The global indexes reported by the partition, or an empty list when
        the partition does not exist
    """
    partition = self._get_partition(partition_id)
    if partition:
        return partition.scan_data_status(data_fields, task_name)
    return []
def generate_batch_meta(
    self,
    partition_id: str,
    batch_global_indexes: list[int],
    data_fields: list[str],
    mode: str = "fetch",
) -> BatchMeta:
    """
    Build the BatchMeta describing a set of samples in one partition.

    Works from the partition's field schema rather than per-sample metadata.
    Consumption state is not modified here; that is the caller's job.

    Args:
        partition_id: ID of the partition
        batch_global_indexes: Sample indexes to include in the batch
        data_fields: Field names to include
        mode: Operation mode - 'fetch', 'insert', or 'force_fetch'

    Returns:
        BatchMeta object containing sample metadata

    Raises:
        ValueError: If the partition doesn't exist or the mode is invalid
    """
    partition = self._get_partition(partition_id)
    if not partition:
        raise ValueError(f"Partition {partition_id} not found")
    if mode not in ("fetch", "insert", "force_fetch"):
        raise ValueError(f"Invalid mode: {mode}")

    num_samples = len(batch_global_indexes)
    field_schema = partition.get_field_schema(data_fields, batch_global_indexes)

    if mode == "insert":
        # Fields not yet known to the partition get a placeholder schema entry.
        for name in data_fields:
            if name in field_schema:
                continue
            field_schema[name] = {
                "dtype": None,
                "shape": None,
                "is_nested": False,
                "is_non_tensor": False,
            }

    if mode == "fetch":
        # Fetch mode reports every sample as produced.
        production_status = np.ones(num_samples, dtype=np.int8)
    else:
        production_status = np.zeros(num_samples, dtype=np.int8)
        if mode == "force_fetch" and partition.production_status is not None and data_fields:
            # force_fetch consults the partition's production table: a sample
            # is flagged only when all requested (known) fields are set to 1.
            name_to_column = partition.field_name_mapping
            columns = [name_to_column.get(name) for name in data_fields if name in name_to_column]
            if columns:
                for position, global_idx in enumerate(batch_global_indexes):
                    if global_idx >= partition.production_status.shape[0]:
                        continue
                    row = partition.production_status[global_idx, columns]
                    if torch.all(row == 1):
                        production_status[position] = 1

    custom_meta_by_index = partition.get_custom_meta(batch_global_indexes)
    backend_meta_by_index = partition.get_field_custom_backend_meta(batch_global_indexes, data_fields)

    # Per-sample lists aligned with batch_global_indexes; absent entries become {}.
    custom_meta_list = []
    custom_backend_meta_list = []
    for global_index in batch_global_indexes:
        custom_meta_list.append(custom_meta_by_index.get(global_index, {}))
    for global_index in batch_global_indexes:
        custom_backend_meta_list.append(backend_meta_by_index.get(global_index, {}))

    return BatchMeta(
        global_indexes=batch_global_indexes,
        partition_ids=[partition_id] * num_samples,
        field_schema=field_schema,
        production_status=production_status,
        custom_meta=custom_meta_list,
        _custom_backend_meta=custom_backend_meta_list,
    )
def clear_partition(self, partition_id: str, clear_consumption: bool = True):
    """
    Delete a whole partition together with its data.

    The partition's data is cleared, its indexes are released back to the
    index manager, the partition is removed from the controller and the
    sampler cache for it is dropped. Unknown partitions only trigger a warning.

    Args:
        partition_id: ID of the partition to clear
        clear_consumption: Whether to also clear consumption status
    """
    logger.debug(f"[{self.controller_id}]: clearing metadata in partition {partition_id}")

    partition = self._get_partition(partition_id)
    if partition:
        owned_indexes = list(self.index_manager.get_indexes_for_partition(partition_id))
        partition.clear_data(owned_indexes, clear_consumption)
        self.index_manager.release_partition(partition_id)
        self.partitions.pop(partition_id)
        self.sampler.clear_cache(partition_id)
    else:
        logger.warning(f"Try to clear an non-existent partition {partition_id}!")
def reset_consumption(self, partition_id: str, task_name: str | None = None):
    """
    Reset a partition's consumption status while keeping its data.

    Lets already-consumed data be consumed again, e.g. to train on the same
    rollout data several times while debugging. Unknown partitions only
    trigger a warning.

    Args:
        partition_id: ID of the partition to reset consumption for
        task_name: Name of the task to reset. If None, resets all tasks.
    """
    logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}")

    partition = self._get_partition(partition_id)
    if partition:
        partition.reset_consumption(task_name)
    else:
        logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!")
def clear_meta(
self,
global_indexes: list[int],
partition_ids: list[str],
clear_consumption: bool = True,
):
"""
Clear meta for individual samples (preserving the partition).
Args:
global_indexes: global_indexes to clear
partition_ids: IDs of the partitions to clear
clear_consumption: Whether to also clear consumption status
"""
logger.debug(
f"[{self.controller_id}]: Clearing meta with global_indexes {global_indexes} in partition {partition_ids}"
)
if global_indexes is None or partition_ids is None:
raise ValueError("global_indexes and partition_ids cannot be None")
if len(global_indexes) != len(partition_ids):
raise ValueError(
f"global_indexes and partition_ids must have the same length, "
f"got {len(global_indexes)} and {len(partition_ids)}"
)
combined = list(zip(partition_ids, global_indexes, strict=True))
combined.sort(key=itemgetter(0))
for partition_id, group in groupby(combined, key=itemgetter(0)):
partition = self._get_partition(partition_id)
if not partition:
raise ValueError(f"Partition {partition_id} not found")
global_indexes_to_clear = [idx for _, idx in group]
if not set(global_indexes_to_clear).issubset(partition.global_indexes):
raise ValueError(
f"Some global_indexes to clear do not exist in partition {partition_id}. "
f"Target: {global_indexes_to_clear}, Existing: {partition.global_indexes}"
)
partition.clear_data(global_indexes_to_clear, clear_consumption)
self.index_manager.release_indexes(partition_id, global_indexes_to_clear)
def kv_retrieve_meta(
    self,
    keys: list[str],
    partition_id: str,
    create: bool = False,
) -> BatchMeta:
    """
    Resolve a list of keys to the BatchMeta of their samples.

    With ``create=True`` a missing partition is created and every unknown key
    is registered under a newly assigned global index. With ``create=False`` a
    missing partition or any unknown key yields an empty BatchMeta.

    The returned metadata covers only the fields that are fully produced for
    all of the resolved samples.

    Args:
        keys: List of keys to retrieve from the controller
        partition_id: Partition id to retrieve from the controller
        create: Whether to register new keys if not found.

    Returns:
        metadata: BatchMeta of the requested keys
    """
    logger.debug(f"[{self.controller_id}] Retrieve keys {keys} in partition {partition_id}")

    partition = self._get_partition(partition_id)
    if partition is None:
        if not create:
            logger.warning(f"Partition {partition_id} not found!")
            return BatchMeta.empty()
        self.create_partition(partition_id)
        partition = self._get_partition(partition_id)
    assert partition is not None

    global_indexes = partition.kv_retrieve_indexes(keys)
    missing_slots = [slot for slot, found in enumerate(global_indexes) if found is None]

    if len(missing_slots) > 0:
        if not create:
            logger.error(f"Keys {[keys[i] for i in missing_slots]} were not found in partition {partition_id}!")
            return BatchMeta.empty()

        # Assign an index to every unknown key: pre-allocated ones first,
        # then freshly allocated ones for the remainder.
        wanted = len(missing_slots)
        batch_global_indexes = partition.activate_pre_allocated_indexes(wanted)
        if len(batch_global_indexes) < wanted:
            batch_global_indexes.extend(
                self.index_manager.allocate_indexes(partition_id, count=(wanted - len(batch_global_indexes)))
            )
        partition.global_indexes.update(batch_global_indexes)

        # Record the key <-> index association in both directions.
        for ordinal, slot in enumerate(missing_slots):
            assigned_index = batch_global_indexes[ordinal]
            missing_key = keys[slot]
            global_indexes[slot] = assigned_index
            partition.keys_mapping[missing_key] = assigned_index
            partition.revert_keys_mapping[assigned_index] = missing_key

        partition.ensure_samples_capacity(max(batch_global_indexes) + 1)

    verified_global_indexes = [idx for idx in global_indexes if idx is not None]
    assert len(verified_global_indexes) == len(keys)

    # A column qualifies when its production flags sum to the sample count.
    produced_per_column = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1)
    col_mask = produced_per_column == len(verified_global_indexes)

    data_fields = [
        field_name
        for field_name, col_idx in partition.field_name_mapping.items()
        if col_idx < len(col_mask) and col_mask[col_idx]
    ]

    return self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
def kv_retrieve_keys(
    self,
    global_indexes: list[int],
    partition_id: str,
) -> list[str | None]:
    """
    Retrieve keys from the controller using a list of global_indexes.

    Args:
        global_indexes: List of global_indexes to retrieve keys from the controller
        partition_id: Partition id to retrieve from the controller

    Returns:
        The keys for ``global_indexes`` as returned by the partition. An empty
        list is returned when the partition does not exist or when at least
        one of the global indexes has no key.
    """
    logger.debug(f"[{self.controller_id}]: Retrieve global_indexes {global_indexes} in partition {partition_id}")

    partition = self._get_partition(partition_id)
    if partition is None:
        logger.warning(f"Partition {partition_id} were not found in controller!")
        return []

    keys = partition.kv_retrieve_keys(global_indexes)

    # Inspect the *retrieved keys* for gaps (the input indexes are never
    # None), and report the global indexes that could not be resolved.
    none_indexes = [idx for idx, value in enumerate(keys) if value is None]
    if len(none_indexes) > 0:
        logger.error(
            f"Key for global_index {[global_indexes[i] for i in none_indexes]} "
            f"were not found in partition {partition_id}!"
        )
        return []

    return keys
def _init_zmq_socket(self):
    """
    Initialize ZMQ sockets for communication.

    Creates a ZMQ context, then binds two ROUTER sockets on this node's IP:
    one for handshakes and one for request handling, each on a free port.
    Binding is retried until both sockets are bound. Finally the connection
    details are published as ``self.zmq_server_info``.
    """
    self.zmq_context = zmq.Context()
    self._node_ip = get_node_ip_address()

    while True:
        # Track this attempt's sockets locally so a failed attempt can be
        # cleaned up before retrying.
        handshake_socket = None
        request_handle_socket = None
        try:
            self._handshake_socket_port = get_free_port(ip=self._node_ip)
            self._request_handle_socket_port = get_free_port(ip=self._node_ip)

            handshake_socket = create_zmq_socket(
                ctx=self.zmq_context,
                socket_type=zmq.ROUTER,
                ip=self._node_ip,
            )
            handshake_socket.bind(format_zmq_address(self._node_ip, self._handshake_socket_port))

            request_handle_socket = create_zmq_socket(
                ctx=self.zmq_context,
                socket_type=zmq.ROUTER,
                ip=self._node_ip,
            )
            request_handle_socket.bind(format_zmq_address(self._node_ip, self._request_handle_socket_port))
        except zmq.ZMQError:
            logger.warning(f"[{self.controller_id}]: Try to bind ZMQ sockets failed, retrying...")
            # Close whatever this attempt opened; otherwise an already-bound
            # socket would be left open and keep holding its port.
            # NOTE(review): assumes create_zmq_socket returns a pyzmq Socket.
            for sock in (handshake_socket, request_handle_socket):
                if sock is not None:
                    sock.close(linger=0)
            continue

        self.handshake_socket = handshake_socket
        self.request_handle_socket = request_handle_socket
        break

    self.zmq_server_info = ZMQServerInfo(
        role=Role.CONTROLLER,
        id=self.controller_id,
        ip=self._node_ip,
        ports={
            "handshake_socket": self._handshake_socket_port,
            "request_handle_socket": self._request_handle_socket_port,
        },
    )
def _wait_connection(self):
    """
    Serve handshake requests from storage managers forever.

    Polls the handshake socket, answers every HANDSHAKE message with a
    HANDSHAKE_ACK and remembers the sender id. A repeated handshake from a
    known sender is acknowledged again, so a lost ACK can be recovered by
    retransmission. Errors while handling a message are logged and the loop
    keeps running.
    """
    poller = zmq.Poller()
    poller.register(self.handshake_socket, zmq.POLLIN)

    logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")

    while True:
        ready_sockets = dict(poller.poll(1000))
        if self.handshake_socket not in ready_sockets:
            continue

        try:
            frames = self.handshake_socket.recv_multipart(copy=False)
            # First frame is the ROUTER identity; the rest is the message.
            identity = frames.pop(0)
            request_msg = ZMQMessage.deserialize(frames)

            if request_msg.request_type != ZMQRequestType.HANDSHAKE:
                continue

            storage_manager_id = request_msg.sender_id

            # Always acknowledge, even for duplicates.
            ack_frames = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE_ACK,
                sender_id=self.controller_id,
                body={},
            ).serialize()
            self.handshake_socket.send_multipart([identity, *ack_frames])

            if storage_manager_id in self._connected_storage_managers:
                logger.debug(
                    f"[{self.controller_id}]: received duplicate handshake from "
                    f"storage manager {storage_manager_id}. Resending ACK."
                )
            else:
                self._connected_storage_managers.add(storage_manager_id)
                storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
                logger.debug(
                    f"[{self.controller_id}]: received handshake from "
                    f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
                    f"Total connected: {len(self._connected_storage_managers)}"
                )
        except Exception as e:
            logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
def _start_process_handshake(self):
    """Launch the daemon thread that runs ``_wait_connection``."""
    handshake_thread = Thread(
        name="TransferQueueControllerWaitConnectionThread",
        target=self._wait_connection,
        daemon=True,
    )
    self.wait_connection_thread = handshake_thread
    handshake_thread.start()
def _start_process_request(self):
    """Launch the daemon thread that runs ``_process_request``."""
    request_thread = Thread(
        name="TransferQueueControllerProcessRequestThread",
        target=self._process_request,
        daemon=True,
    )
    self.process_request_thread = request_thread
    request_thread.start()
def _process_request(self):
    """
    Main request processing loop - adapted for partition-based operations.

    Runs forever: receives one multipart message from the request socket,
    dispatches on its request type, builds the matching response and sends it
    back to the requesting identity. Each handler is timed through
    ``monitor.measure``; the monitor is the metrics exporter when one has been
    started, otherwise a local ``IntervalPerfMonitor``.

    A message with an unsupported request type is logged and dropped without
    a reply, so it can neither stop the loop nor cause a response built for a
    previous request to be sent to the wrong client.
    """
    logger.info(f"[{self.controller_id}]: start processing requests...")

    perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)

    while True:
        # Re-evaluated per request so a later start_metrics() takes effect.
        monitor = self._metrics if self._metrics is not None else perf_monitor

        messages = self.request_handle_socket.recv_multipart(copy=False)
        # First frame is the ROUTER identity; the rest is the message.
        identity = messages.pop(0)
        serialized_msg = messages
        request_msg = ZMQMessage.deserialize(serialized_msg)

        if request_msg.request_type == ZMQRequestType.GET_META:
            with monitor.measure(op_type="GET_META"):
                params = request_msg.body

                metadata = self.get_metadata(
                    data_fields=params["data_fields"],
                    batch_size=params["batch_size"],
                    partition_id=params["partition_id"],
                    mode=params.get("mode", "fetch"),
                    task_name=params.get("task_name"),
                    sampling_config=params.get("sampling_config", {}),
                )

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )

        elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
            with monitor.measure(op_type="NOTIFY_DATA_UPDATE"):
                message_data = request_msg.body
                partition_id = message_data.get("partition_id")
                global_indexes = message_data.get("global_indexes", [])

                success = self.update_production_status(
                    partition_id=partition_id,
                    global_indexes=global_indexes,
                    field_schema=message_data.get("field_schema", {}),
                    custom_backend_meta=message_data.get("custom_backend_meta", {}),
                )

                if success:
                    if self._metrics is not None:
                        self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes))
                    logger.debug(f"[{self.controller_id}]: Updated production status for partition {partition_id}")

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={
                        "controller_id": self.controller_id,
                        "partition_id": partition_id,
                        "success": success,
                    },
                )

        elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META:
            with monitor.measure(op_type="GET_PARTITION_META"):
                params = request_msg.body
                partition_id = params["partition_id"]

                partition = self._get_partition(partition_id)
                if partition is not None:
                    # All fields known to the partition, all of its samples.
                    partition_data_fields = list(partition.field_name_mapping.keys())
                    metadata = self.get_metadata(
                        data_fields=partition_data_fields,
                        partition_id=partition_id,
                        mode="force_fetch",
                    )
                else:
                    metadata = None

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_PARTITION_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )

        elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
            with monitor.measure(op_type="SET_CUSTOM_META"):
                params = request_msg.body
                partition_custom_meta = params["partition_custom_meta"]

                self.set_custom_meta(partition_custom_meta=partition_custom_meta)

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.SET_CUSTOM_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"message": "Successfully set custom_meta"},
                )

        elif request_msg.request_type == ZMQRequestType.CLEAR_META:
            with monitor.measure(op_type="CLEAR_META"):
                params = request_msg.body
                global_indexes = params["global_indexes"]
                partition_ids = params["partition_ids"]

                self.clear_meta(global_indexes, partition_ids)

                if self._metrics is not None:
                    self._metrics.record_samples("CLEAR_META", len(global_indexes))

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CLEAR_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"message": f"Clear samples operation completed by controller {self.controller_id}"},
                )

        elif request_msg.request_type == ZMQRequestType.CLEAR_PARTITION:
            with monitor.measure(op_type="CLEAR_PARTITION"):
                params = request_msg.body
                partition_id = params["partition_id"]

                self.clear_partition(partition_id)

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CLEAR_PARTITION_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"message": f"Clear partition operation completed by controller {self.controller_id}"},
                )

        elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION:
            with monitor.measure(op_type="GET_CONSUMPTION"):
                params = request_msg.body

                global_index, consumption_status = self.get_consumption_status(
                    params["partition_id"], params["task_name"]
                )

                # Optional filter applied to the status tensor only.
                sample_filter = params.get("sample_filter")
                if sample_filter and consumption_status is not None:
                    consumption_status = consumption_status[sample_filter]

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={
                        "partition_id": params["partition_id"],
                        "global_index": global_index,
                        "consumption_status": consumption_status,
                    },
                )

        elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION:
            with monitor.measure(op_type="RESET_CONSUMPTION"):
                params = request_msg.body
                partition_id = params["partition_id"]
                task_name = params.get("task_name")

                # Failures are reported to the client instead of raised.
                try:
                    self.reset_consumption(partition_id, task_name)
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.RESET_CONSUMPTION_RESPONSE,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body={
                            "partition_id": partition_id,
                            "success": True,
                            "message": f"Consumption reset for partition {partition_id}",
                        },
                    )
                except Exception as e:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.RESET_CONSUMPTION_RESPONSE,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body={
                            "partition_id": partition_id,
                            "success": False,
                            "message": str(e),
                        },
                    )

        elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION:
            with monitor.measure(op_type="GET_PRODUCTION"):
                params = request_msg.body

                global_index, production_status = self.get_production_status(
                    params["partition_id"], params["data_fields"]
                )

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.PRODUCTION_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={
                        "partition_id": params["partition_id"],
                        "global_index": global_index,
                        "production_status": production_status,
                    },
                )

        elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
            with monitor.measure(op_type="GET_LIST_PARTITIONS"):
                partition_ids = self.list_partitions()

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.LIST_PARTITIONS_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"partition_ids": partition_ids},
                )

        elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_META:
            with monitor.measure(op_type="KV_RETRIEVE_META"):
                params = request_msg.body
                keys = params["keys"]
                partition_id = params["partition_id"]
                create = params["create"]

                metadata = self.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create)

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.KV_RETRIEVE_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )

        elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS:
            with monitor.measure(op_type="KV_RETRIEVE_KEYS"):
                params = request_msg.body
                global_indexes = params["global_indexes"]
                partition_id = params["partition_id"]

                keys = self.kv_retrieve_keys(global_indexes=global_indexes, partition_id=partition_id)

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"keys": keys},
                )

        elif request_msg.request_type == ZMQRequestType.KV_LIST:
            with monitor.measure(op_type="KV_LIST"):
                params = request_msg.body
                partition_id = params["partition_id"]

                # None means "every partition"; otherwise just the one named.
                if partition_id is None:
                    partition_id = list(self.partitions.keys())
                else:
                    partition_id = [partition_id]

                message = "success"
                partition_info = {}
                for pid in partition_id:
                    partition = self._get_partition(pid)
                    if partition:
                        keys = list(partition.keys_mapping.keys())
                        single_partition_info = {
                            k: partition.custom_meta.get(partition.keys_mapping[k], {}) for k in keys
                        }
                        partition_info[pid] = single_partition_info
                    else:
                        message = f"partition {pid} does not exist"

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.KV_LIST_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"partition_info": partition_info, "message": message},
                )

        else:
            # No handler: do not fall through to the send below, where
            # response_msg would be unbound or left over from a prior request.
            logger.warning(
                f"[{self.controller_id}]: Received unsupported request type {request_msg.request_type} "
                f"from {request_msg.sender_id}; no response sent."
            )
            continue

        self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
def get_zmq_server_info(self) -> ZMQServerInfo:
    """Return the ``ZMQServerInfo`` stored on this controller as ``zmq_server_info``."""
    server_info = self.zmq_server_info
    return server_info
def store_config(self, conf: DictConfig) -> None:
    """
    Keep a reference to the global TransferQueue config on this controller.

    Args:
        conf: Config object to remember; it can be read back with ``get_config``
    """
    self.tq_config = conf
def get_config(self) -> DictConfig:
    """Return the global TransferQueue config previously saved via ``store_config``."""
    stored_config = self.tq_config
    return stored_config
def register_sampler(
    self,
    sampler: BaseSampler | type[BaseSampler] = SequentialSampler,
) -> None:
    """
    Install the sampler used for data sampling on an existing controller.

    Args:
        sampler: Either a ready-made sampler or a sampler class.
            - A BaseSampler instance is stored as-is
            - A BaseSampler subclass is instantiated with no arguments
            - Defaults to the SequentialSampler class
            - Example: sampler=GRPOGroupNSampler() (instance)
            - Example: sampler=SequentialSampler (class)

    Raises:
        TypeError: If ``sampler`` is neither a BaseSampler instance nor a
            BaseSampler subclass
    """
    if isinstance(sampler, BaseSampler):
        chosen_sampler = sampler
    elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
        chosen_sampler = sampler()
    else:
        raise TypeError(
            f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
        )
    self.sampler = chosen_sampler
def _build_metrics_snapshot(self) -> dict:
    """Build a plain-dict snapshot of controller state for the metrics exporter.

    The snapshot contains only primitive / dict values — no references to
    live controller objects — so the metrics thread can read it safely.

    Building is best-effort per partition: if collecting the statistics of
    one partition fails, that partition is left out of the snapshot and the
    failure is logged at debug level; the other partitions are still reported.

    Returns:
        Dict with the per-partition statistics under ``"partitions"`` plus the
        number of allocated and reusable global indexes.
    """
    partitions_data: dict = {}
    # Iterate over a copy so partitions added/removed meanwhile do not break
    # the iteration.
    for pid, partition in list(self.partitions.items()):
        try:
            stats = partition.get_statistics()

            production_statistics: dict = {}
            if "production_progress" in stats:
                production_statistics["all"] = {
                    "production_progress": stats["production_progress"],
                }
            # Each registered task is reported with the partition-wide progress.
            for task_name in stats.get("registered_tasks", []):
                production_statistics[task_name] = {
                    "production_progress": stats.get("production_progress", 0),
                }

            partitions_data[pid] = {
                "total_samples_num": stats["total_samples_num"],
                "production_statistics": production_statistics,
                "consumption_statistics": {
                    task_name: {"consumption_progress": cstats.get("consumption_progress", 0)}
                    for task_name, cstats in stats.get("consumption_statistics", {}).items()
                },
            }
        except Exception as e:
            # Still best-effort, but no longer silent.
            logger.debug(f"[{self.controller_id}]: Skipping partition {pid} in metrics snapshot: {e}")

    return {
        "partitions": partitions_data,
        "global_index_allocated": len(self.index_manager.allocated_indexes),
        "global_index_reusable": len(self.index_manager.reusable_indexes),
    }
def _push_metrics_snapshot(self) -> None:
    """
    Hand a freshly built metrics snapshot to the exporter.

    Does nothing when no exporter has been started. Any failure while
    building or delivering the snapshot is logged at debug level and
    otherwise ignored.
    """
    if self._metrics is not None:
        try:
            fresh_snapshot = self._build_metrics_snapshot()
            self._metrics.update_controller_snapshot(fresh_snapshot)
        except Exception as e:
            logger.debug(f"[{self.controller_id}]: Failed to push metrics snapshot: {e}")
def start_metrics(self, port: int = 0) -> str:
    """Initialize and start the Prometheus metrics exporter.

    This creates a ``TQMetricsExporter``, starts an HTTP ``/metrics``
    endpoint and a background thread that periodically pushes controller
    snapshots to it. The method is safe to call multiple times --
    subsequent calls return the existing endpoint.

    The exporter is only stored on the controller after it has started
    successfully, so a failed start leaves the controller without an
    exporter and the call can simply be retried.

    Args:
        port: HTTP port for the /metrics endpoint (0 = auto-assign).

    Returns:
        The metrics endpoint address in ``host:port`` format.
    """
    if self._metrics is not None:
        return self._metrics_endpoint

    # Imported lazily so the metrics module is only loaded when used.
    from transfer_queue.metrics import TQMetricsExporter

    exporter = TQMetricsExporter()
    endpoint = exporter.start(node_ip=self._node_ip, port=port)

    # Publish the endpoint before the exporter: anything that observes
    # ``self._metrics`` as set will also find a valid endpoint.
    self._metrics_endpoint = endpoint
    self._metrics = exporter

    self._metrics_snapshot_thread = Thread(
        target=self._metrics_snapshot_loop,
        name="TQMetricsSnapshotThread",
        daemon=True,
    )
    self._metrics_snapshot_thread.start()

    logger.info(f"[{self.controller_id}]: Prometheus metrics exporter started on {self._metrics_endpoint}")
    return self._metrics_endpoint
def _metrics_snapshot_loop(self) -> None:
    """
    Push metrics snapshots to the exporter in an endless loop.

    After every push the loop sleeps for ``TQ_METRICS_COLLECT_INTERVAL``
    (imported from ``transfer_queue.metrics``) before pushing again.
    """
    from transfer_queue.metrics import TQ_METRICS_COLLECT_INTERVAL

    pause_between_pushes = TQ_METRICS_COLLECT_INTERVAL
    while True:
        self._push_metrics_snapshot()
        time.sleep(pause_between_pushes)
def register_storage_units_for_metrics(self, storage_unit_infos: dict) -> None:
    """
    Forward storage unit ZMQ endpoints to the metrics exporter.

    Has no effect when no metrics exporter has been started.

    Args:
        storage_unit_infos: Mapping of storage unit names/IDs to ``ZMQServerInfo``.
    """
    if self._metrics is None:
        return
    self._metrics.register_storage_units(storage_unit_infos)