import itertools
from typing import Any
import ray
import torch
from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
@ray.remote(max_concurrency=8)
class RayObjectRefStorage:
"""Ray object ref storage."""
def __init__(self):
self.storage_dict = {}
def put_obj_ref(self, keys: list[str], obj_refs: list[ray.ObjectRef]):
"""Put object ref to remote storage."""
self.storage_dict.update(itertools.starmap(lambda k, v: (k, v), zip(keys, obj_refs, strict=True)))
def get_obj_ref(self, keys: list[str]) -> list[ray.ObjectRef]:
"""Get object ref from remote storage."""
obj_refs = [self.storage_dict.get(key, None) for key in keys]
return obj_refs
def clear_obj_ref(self, keys: list[str]):
"""Clear object ref from remote storage."""
for key in keys:
if key in self.storage_dict:
del self.storage_dict[key]
@StorageClientFactory.register("RayStorageClient")
class RayStorageClient(StorageKVClient):
"""
Storage client for Ray RDT.
"""
def __init__(self, config=None):
if not ray.is_initialized():
raise RuntimeError("Ray is not initialized. Please call ray.init() before creating RayStorageClient.")
try:
self.storage_actor = ray.get_actor("RayObjectRefStorage", namespace="transfer_queue")
except ValueError:
self.storage_actor = RayObjectRefStorage.options(
name="RayObjectRefStorage", namespace="transfer_queue", get_if_exists=False
).remote()
def put(self, keys: list[str], values: list[Any]) -> list[Any] | None:
"""
Store tensors to remote storage.
Args:
keys (list): List of string keys
values (list): List of torch.Tensor on GPU(CUDA) or CPU
"""
if not isinstance(keys, list) or not isinstance(values, list):
raise ValueError(f"keys and values must be lists, but got {type(keys)} and {type(values)}")
if len(keys) != len(values):
raise ValueError("Number of keys must match number of values")
transports = itertools.repeat("nixl")
obj_refs = list(
itertools.starmap(
lambda v, tx: ray.put(v, _tensor_transport=tx) if isinstance(v, torch.Tensor) else ray.put(v),
zip(values, transports, strict=False),
)
)
ray.get(self.storage_actor.put_obj_ref.remote(keys, obj_refs))
return None
def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]:
"""
Retrieve objects from remote storage.
Args:
keys (list): List of string keys to fetch.
shapes (list, optional): Ignored. For compatibility with KVStorageManager.
dtypes (list, optional): Ignored. For compatibility with KVStorageManager.
custom_backend_meta (list, optional): Ray object ref for each key
Returns:
list: List of retrieved objects
"""
if not isinstance(keys, list):
raise ValueError(f"keys must be a list, but got {type(keys)}")
obj_refs = ray.get(self.storage_actor.get_obj_ref.remote(keys))
try:
values = ray.get(obj_refs)
except Exception as e:
raise RuntimeError(f"Failed to retrieve value for key '{keys}': {e}") from e
return values
def clear(self, keys: list[str], custom_backend_meta=None):
"""
Delete entries from storage by keys.
Args:
keys (list): List of keys to delete
custom_backend_meta (List[Any], optional): ...
"""
ray.get(self.storage_actor.clear_obj_ref.remote(keys))