from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional
import torch
from torch import Tensor
from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.serial_utils import batch_decode_from, batch_encode_into
from transfer_queue.utils.yuanrong_utils import find_reachable_host
logger = get_logger(__name__)
YUANRONG_DATASYSTEM_IMPORTED: bool = True
try:
from yr import datasystem
except ImportError:
YUANRONG_DATASYSTEM_IMPORTED = False
class StorageStrategy(ABC):
"""Abstract base class for storage strategies."""
@staticmethod
@abstractmethod
def init(config: dict) -> Optional["StorageStrategy"]:
"""Initialize strategy from config; return None if not applicable."""
@abstractmethod
def strategy_tag(self) -> Any:
"""Return metadata identifying this strategy (e.g., string name, byte tag)."""
@abstractmethod
def supports_put(self, value: Any) -> bool:
"""Check if this strategy can store the given value."""
@abstractmethod
def put(self, keys: list[str], values: list[Any]) -> None:
"""Store key-value pairs using this strategy."""
@abstractmethod
def supports_get(self, strategy_tag: Any) -> bool:
"""Check if this strategy can retrieve data with given tag."""
@abstractmethod
def get(self, keys: list[str], **kwargs) -> list[Any | None]:
"""Retrieve values by keys; kwargs may include shapes/dtypes."""
@abstractmethod
def supports_clear(self, strategy_tag: Any) -> bool:
"""Check if this strategy owns data identified by metadata."""
@abstractmethod
def clear(self, keys: list[str]) -> None:
"""Delete keys from storage."""
class NPUTensorKVClientAdapter(StorageStrategy):
"""Adapter for YuanRong's high-performance NPU tensor storage.
Using yr.datasystem.DsTensorClient to connect datasystem backends.
"""
KEYS_LIMIT: int = 10_000
def __init__(self, config: dict):
port = config.get("worker_port")
if port is None or not isinstance(port, int):
raise ValueError("Missing or invalid 'worker_port' in config")
logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...")
host = find_reachable_host(port)
if host is None:
raise ValueError(
f"Could not find any reachable host for Yuanrong port {port}. "
"Please ensure yuanrong datasystem is running."
)
logger.info(f"Using auto-detected host: {host}")
self.device_id = torch.npu.current_device()
torch.npu.set_device(self.device_id)
self._ds_client = datasystem.DsTensorClient(host, port, self.device_id)
self._ds_client.init()
logger.info("YuanrongStorageClient: Create DsTensorClient to connect with yuanrong-datasystem backend!")
@staticmethod
def init(config: dict) -> Optional["StorageStrategy"]:
"""Initialize only if NPU and torch_npu are available."""
torch_npu_imported: bool = True
try:
import torch_npu
except ImportError:
torch_npu_imported = False
enable = config.get("enable_yr_npu_transport", True)
if not (enable and torch_npu_imported and torch.npu.is_available()):
return None
return NPUTensorKVClientAdapter(config)
def strategy_tag(self) -> str:
"""Strategy tag for NPU tensor storage. Using a single byte is for better performance."""
return "1"
def supports_put(self, value: Any) -> bool:
"""Supports contiguous NPU tensors only."""
if not (isinstance(value, torch.Tensor) and value.device.type == "npu"):
return False
return value.is_contiguous()
def put(self, keys: list[str], values: list[Any]) -> None:
"""Store NPU tensors in batches; deletes before overwrite."""
for i in range(0, len(keys), self.KEYS_LIMIT):
batch_keys = keys[i : i + self.KEYS_LIMIT]
batch_values = values[i : i + self.KEYS_LIMIT]
try:
self._ds_client.delete(batch_keys)
except Exception:
pass
self._ds_client.mset_d2h(batch_keys, batch_values)
def supports_get(self, strategy_tag: str) -> bool:
"""Matches 'DsTensorClient' Strategy tag."""
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
def get(self, keys: list[str], **kwargs) -> list[Any | None]:
"""Fetch NPU tensors using pre-allocated empty buffers."""
shapes = kwargs.get("shapes", None)
dtypes = kwargs.get("dtypes", None)
if shapes is None or dtypes is None:
raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes")
results = []
for i in range(0, len(keys), self.KEYS_LIMIT):
batch_keys = keys[i : i + self.KEYS_LIMIT]
batch_shapes = shapes[i : i + self.KEYS_LIMIT]
batch_dtypes = dtypes[i : i + self.KEYS_LIMIT]
batch_values = self._create_empty_npu_tensorlist(batch_shapes, batch_dtypes)
self._ds_client.mget_h2d(batch_keys, batch_values)
results.extend(batch_values)
return results
def supports_clear(self, strategy_tag: str) -> bool:
"""Matches 'DsTensorClient' strategy tag."""
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
def clear(self, keys: list[str]) -> None:
"""Delete NPU tensor keys in batches."""
for i in range(0, len(keys), self.KEYS_LIMIT):
batch = keys[i : i + self.KEYS_LIMIT]
self._ds_client.delete(batch)
def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> list[Tensor]:
"""
Create a list of empty NPU tensors with given shapes and dtypes.
Args:
shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)])
dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64])
Returns:
list[Tensor]: List of uninitialized NPU tensors
"""
tensors: list[Tensor] = []
for shape, dtype in zip(shapes, dtypes, strict=True):
tensor = torch.empty(shape, dtype=dtype, device=f"npu:{self.device_id}")
tensors.append(tensor)
return tensors
class GeneralKVClientAdapter(StorageStrategy):
"""Adapter for general-purpose KV storage with serialization.
Using yr.datasystem.KVClient to connect datasystem backends.
The serialization method uses 'batch_encode_into' and 'batch_decode_from' from 'transfer_queue.utils.serial_utils'.
"""
PUT_KEYS_LIMIT: int = 10_000
GET_CLEAR_KEYS_LIMIT: int = 10_000
DS_MAX_WORKERS: int = 16
def __init__(self, config: dict):
port = config.get("worker_port")
if port is None or not isinstance(port, int):
raise ValueError("Missing or invalid 'worker_port' in config")
logger.info(f"Auto-detecting reachable host for Yuanrong port {port}...")
host = find_reachable_host(port)
if host is None:
raise ValueError(
f"Could not find any reachable host for Yuanrong port {port}. "
"Please ensure yuanrong datasystem is running."
)
logger.info(f"Using auto-detected host: {host}")
self._ds_client = datasystem.KVClient(host, port)
self._ds_client.init()
logger.info("YuanrongStorageClient: Create KVClient to connect with yuanrong-datasystem backend!")
@staticmethod
def init(config: dict) -> Optional["StorageStrategy"]:
"""Always enabled for general objects."""
return GeneralKVClientAdapter(config)
def strategy_tag(self) -> str:
"""Strategy tag for general KV storage. Using a single byte is for better performance."""
return "2"
def supports_put(self, value: Any) -> bool:
"""Accepts any Python object."""
return True
def put(self, keys: list[str], values: list[Any]) -> None:
"""Store objects via zero-copy serialization in batches."""
for i in range(0, len(keys), self.PUT_KEYS_LIMIT):
batch_keys = keys[i : i + self.PUT_KEYS_LIMIT]
batch_vals = values[i : i + self.PUT_KEYS_LIMIT]
self.mset_zero_copy(batch_keys, batch_vals)
def supports_get(self, strategy_tag: str) -> bool:
"""Matches 'KVClient' strategy tag."""
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
def get(self, keys: list[str], **kwargs) -> list[Any | None]:
"""Retrieve and deserialize objects in batches."""
results = []
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
objects = self.mget_zero_copy(batch_keys)
results.extend(objects)
return results
def supports_clear(self, strategy_tag: str) -> bool:
"""Matches 'KVClient' strategy tag."""
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
def clear(self, keys: list[str]) -> None:
"""Delete keys in batches."""
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
self._ds_client.delete(batch_keys)
def mset_zero_copy(self, keys: list[str], objs: list[Any]):
"""Store multiple objects in zero-copy mode using parallel serialization and buffer packing.
Args:
keys (list[str]): List of string keys under which the objects will be stored.
objs (list[Any]): List of Python objects to store (e.g., tensors, strings).
"""
buffers: list = []
def alloc(sizes):
mcreate_bufs = self._ds_client.mcreate(keys, sizes)
buffers.extend(mcreate_bufs)
return [buf.MutableData() for buf in mcreate_bufs]
batch_encode_into(objs, alloc, num_workers=self.DS_MAX_WORKERS)
self._ds_client.mset_buffer(buffers)
def mget_zero_copy(self, keys: list[str]) -> list[Any]:
"""Retrieve multiple objects in zero-copy mode by directly deserializing from shared memory buffers.
Args:
keys (list[str]): List of string keys to retrieve from storage.
Returns:
list[Any]: List of deserialized objects corresponding to the input keys.
"""
buffers = self._ds_client.get_buffers(keys)
valid_indexes = [i for i, buf in enumerate(buffers) if buf is not None]
if valid_indexes and len(valid_indexes) < len(keys):
logger.warning(
f"{len(keys) - len(valid_indexes)} requested keys were not found in openYuanrong-datasystem storage. "
f"Returned results will contain None for these keys."
)
valid_bufs = [buffers[i] for i in valid_indexes]
decoded_objs = batch_decode_from(valid_bufs)
results = [None] * len(keys)
for idx, obj in zip(valid_indexes, decoded_objs, strict=True):
results[idx] = obj
return results
@StorageClientFactory.register("YuanrongStorageClient")
class YuanrongStorageClient(StorageKVClient):
"""
Storage client for YuanRong DataSystem.
Use different storage strategies depending on the data type.
Supports storing and fetching both:
- NPU tensors via NPUTensorKVClientAdapter (for high performance).
- General objects (CPU tensors, str, bool, list, etc.) via GeneralKVClientAdapter with serialization.
"""
ROUTE_ITEM_AS_VALUE = "value"
ROUTE_ITEM_AS_BACKEND_META = "backend_meta"
def __init__(self, config: dict[str, Any]):
if not YUANRONG_DATASYSTEM_IMPORTED:
raise ImportError("YuanRong DataSystem not installed.")
port = config.get("worker_port")
if port is None or not isinstance(port, int):
raise ValueError("Missing or invalid 'worker_port' in config")
super().__init__(config)
storage_strategies_priority = [NPUTensorKVClientAdapter, GeneralKVClientAdapter]
self._strategies: list[StorageStrategy] = []
for strategy_cls in storage_strategies_priority:
strategy = strategy_cls.init(config)
if strategy is not None:
self._strategies.append(strategy)
if not self._strategies:
raise RuntimeError("No storage backend available for YuanrongStorageClient")
def put(self, keys: list[str], values: list[Any]) -> list[str]:
"""Stores multiple key-value pairs to remote storage.
Automatically routes NPU tensors to high-performance tensor storage,
and other objects to general-purpose KV storage.
Args:
keys (List[str]): List of unique string identifiers.
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
Returns:
List[str]: storage strategy tag of YuanrongStorageClient in the same order as input keys.
"""
if not isinstance(keys, list) or not isinstance(values, list):
raise ValueError("keys and values must be lists")
if len(keys) != len(values):
raise ValueError("Number of keys must match number of values")
routed_indexes = self._route_to_strategies(
values, lambda strategy_, item_: strategy_.supports_put(item_), item_label=self.ROUTE_ITEM_AS_VALUE
)
def put_task(strategy, indexes):
strategy.put([keys[i] for i in indexes], [values[i] for i in indexes])
return strategy.strategy_tag(), indexes
strategy_tags: list[str] = [""] * len(keys)
for tag, indexes in self._dispatch_tasks(routed_indexes, put_task):
for original_index in indexes:
strategy_tags[original_index] = tag
return strategy_tags
def get(
self,
keys: list[str],
shapes: list[Any] | None = None,
dtypes: list[Any] | None = None,
custom_backend_meta: list[str] | None = None,
) -> list[Any]:
"""Retrieves multiple values from remote storage with expected metadata.
Requires shape and dtype hints to reconstruct NPU tensors correctly.
Args:
keys (List[str]): Keys to fetch.
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
custom_backend_meta (List[str]): StorageStrategy tag for each key
Returns:
List[Any]: Retrieved values in the same order as input keys.
"""
if shapes is None or dtypes is None or custom_backend_meta is None:
raise ValueError("YuanrongStorageClient.get() needs Expected shapes, dtypes and custom_backend_meta")
if not (len(keys) == len(shapes) == len(dtypes) == len(custom_backend_meta)):
raise ValueError("Lengths of keys, shapes, dtypes, custom_backend_meta must match")
if any(not tag for tag in custom_backend_meta):
raise ValueError(
"Some keys have no backend metadata (empty string), indicating they "
"were not previously stored. Ensure all keys have been put before calling get."
)
strategy_tags = custom_backend_meta
routed_indexes = self._route_to_strategies(
strategy_tags,
lambda strategy_, item_: strategy_.supports_get(item_),
item_label=self.ROUTE_ITEM_AS_BACKEND_META,
)
def get_task(strategy, indexes):
res = strategy.get(
[keys[i] for i in indexes], shapes=[shapes[i] for i in indexes], dtypes=[dtypes[i] for i in indexes]
)
return res, indexes
results = [None] * len(keys)
for strategy_res, indexes in self._dispatch_tasks(routed_indexes, get_task):
for value, original_index in zip(strategy_res, indexes, strict=True):
results[original_index] = value
return results
def clear(self, keys: list[str], custom_backend_meta: list[str] | None = None) -> None:
"""Deletes multiple keys from remote storage.
Args:
keys (List[str]): List of keys to remove.
custom_backend_meta (List[str]): StorageStrategy tag for each key
"""
if not isinstance(keys, list) or not isinstance(custom_backend_meta, list):
raise ValueError("keys and custom_backend_meta must be a list")
if len(custom_backend_meta) != len(keys):
raise ValueError("custom_backend_meta length must match keys")
strategy_tags = custom_backend_meta
routed_indexes = self._route_to_strategies(
strategy_tags,
lambda strategy_, item_: strategy_.supports_clear(item_),
ignore_unmatched=True,
item_label=self.ROUTE_ITEM_AS_BACKEND_META,
)
def clear_task(strategy, indexes):
strategy.clear([keys[i] for i in indexes])
self._dispatch_tasks(routed_indexes, clear_task)
def _route_to_strategies(
self,
items: list[Any],
selector: Callable[[StorageStrategy, Any], bool],
*,
ignore_unmatched: bool = False,
item_label: str,
) -> dict[StorageStrategy, list[int]]:
"""Groups item indices by the first strategy that supports them.
Used to route data to storage strategies by grouped indexes.
Args:
items: A list used to distinguish which storage strategy the data is routed to.
e.g., route <keys, values> for put based on types of values,
or route <keys, Optional[shapes], Optional[dtypes]> for get/clear based on strategy_tags.
The order must correspond to the original keys.
selector: A function that determines whether a strategy supports an item.
Signature: `(strategy: StorageStrategy, item: Any) -> bool`.
ignore_unmatched: If True, items that don't match any strategy will be ignored (not included in output).
If False, a ValueError will be raised for any unmatched item.
item_label: Description of what `items` represents, used in error messages.
Use ROUTE_ITEM_AS_VALUE for put (user-provided data),
or ROUTE_ITEM_AS_BACKEND_META for get/clear (backend metadata).
Returns:
A dictionary mapping each active strategy to a list of indexes in `items`
that it should handle. Every index appears exactly once.
"""
unmatched_count = 0
warning_count = 0
routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies}
for i, item in enumerate(items):
for strategy in self._strategies:
if selector(strategy, item):
routed_indexes[strategy].append(i)
break
else:
if ignore_unmatched:
if item:
warning_count += 1
unmatched_count += 1
else:
if item_label == self.ROUTE_ITEM_AS_BACKEND_META:
raise ValueError(
"Cannot retrieve stored data because the backend that originally "
"stored it is unavailable in the current process or node. Please "
"check that the configuration and NPU resource availability are "
"consistent across all processes and nodes."
)
else:
raise ValueError(f"No storage backend can handle {item_label} of type {type(item).__name__}.")
if warning_count > 0:
logger.warning(
f"{warning_count} stored items could not be processed because the backend "
f"that originally handled them may be unavailable in the current process or "
f"node. Please check that the configuration and NPU resource availability "
f"are consistent across all processes and nodes."
)
if unmatched_count > warning_count:
logger.debug(
f"{unmatched_count - warning_count} items with empty {item_label} "
f"will be silently skipped (likely not previously stored)."
)
return routed_indexes
@staticmethod
def _dispatch_tasks(routed_tasks: dict[StorageStrategy, list[int]], task_function: Callable) -> list[Any]:
"""Executes tasks across one or more storage strategies, optionally in parallel.
Optimizes for common case: if only one strategy is involved, runs synchronously
to avoid thread overhead. Otherwise, uses thread pool for concurrency.
Args:
routed_tasks: Mapping from strategy to list of indexes it should process.
task_function: Callable accepting `(strategy, list_of_indexes)` and returning any result.
Returns:
List of results from `task_function`, one per active strategy, in arbitrary order.
Each result typically includes data and the corresponding indices for reassembly.
"""
active_tasks = [(strategy, indexes) for strategy, indexes in routed_tasks.items() if indexes]
if not active_tasks:
return []
if len(active_tasks) == 1:
return [task_function(*active_tasks[0])]
max_workers = min(len(active_tasks), 4)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(task_function, strategy, indexes) for strategy, indexes in active_tasks]
return [f.result() for f in futures]