from dataclasses import dataclass, field
from typing import Optional, Any
import ray
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from minirl.utils.utils import get_master_addr_port
def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]):
    """Return the placement groups ordered by the address of their host node.

    Each group is located through the node that holds its bundle 0; the
    node-manager address of that node is used as the sort key and is compared
    as a plain string. The input list is not modified.

    Args:
        pgs: Placement groups to order.

    Returns:
        A new list containing the same placement groups, sorted by address.
    """
    # Node ID -> node-manager address for every node Ray reports.
    address_by_node = {}
    for node in ray.nodes():
        address_by_node[node["NodeID"]] = node["NodeManagerAddress"]

    def host_address(pg):
        # Only the entry for bundle 0 is consulted to locate the group.
        table = ray._private.state.state.placement_group_table(pg.id)
        return address_by_node[table["bundles_to_node_id"][0]]

    address_by_pg = {pg.id: host_address(pg) for pg in pgs}
    return sorted(pgs, key=lambda pg: address_by_pg[pg.id])
class RayResourcePool:
    """Describe how many worker processes to reserve and reserve them in Ray.

    One placement group is created per entry of ``process_on_nodes``; the
    entry's value is the number of bundles (one per process) in that group.
    """

    def __init__(
        self,
        process_on_nodes: Optional[list[int]] = None,
        use_gpu: bool = True,
        name_prefix: Optional[str] = None,
        max_colocate_count: int = 1,
    ) -> None:
        """Store the pool layout; no Ray resources are reserved here.

        Args:
            process_on_nodes: Process count for each placement group. ``None``
                is treated as an empty list. The list is stored by reference.
            use_gpu: Whether each bundle should also request one accelerator.
            name_prefix: Prefix for generated placement-group names; defaults
                to ``"default_pool"`` when ``None``.
            max_colocate_count: Number of CPUs requested by each bundle.
        """
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
        self.max_colocate_count = max_colocate_count
        # Fixed value; it is not read anywhere else in this class.
        self.n_gpus_per_node = 8
        self.use_gpu = use_gpu
        self.name_prefix = "default_pool" if name_prefix is None else name_prefix
        # Defined up front so that reading `pgs` before get_placement_groups()
        # has run yields None instead of raising AttributeError.
        self.pgs = None

    @property
    def world_size(self):
        """Total number of processes across all placement groups."""
        return sum(self._store)

    @property
    def store(self):
        """The per-placement-group process counts (the stored list itself)."""
        return self._store

    def __call__(self):
        """Return the per-placement-group process counts, same as ``store``."""
        return self._store

    def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="npu"):
        """Create the placement groups for this pool and wait until ready.

        Every call creates a fresh set of placement groups; nothing is reused
        from a previous call.

        Args:
            strategy: Ray placement strategy passed to ``placement_group``.
            name: Name prefix for the groups. When falsy, a prefix is built
                from ``name_prefix`` and the process counts.
            device_name: ``"npu"`` and ``"cuda"`` are mapped to the bundle
                resource keys ``"NPU"`` and ``"GPU"``; any other value is used
                as the resource key unchanged.

        Returns:
            The placement groups in creation order, so index ``i`` corresponds
            to ``store[i]``. A copy sorted by node address is kept on
            ``self.pgs``.
        """
        if name:
            pg_name_prefix = name
        else:
            counts = "_".join(str(count) for count in self._store)
            pg_name_prefix = f"{self.name_prefix}verl_group_{counts}:"

        resource_key = {"npu": "NPU", "cuda": "GPU"}.get(device_name, device_name)

        bundle = {"CPU": self.max_colocate_count}
        if self.use_gpu:
            bundle[resource_key] = 1

        # One list of identical bundles per placement group.
        pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]
        pgs = [
            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx))
            for idx, bundles in enumerate(pg_scheme)
        ]
        # Block until every group has been scheduled.
        ray.get([pg.ready() for pg in pgs])

        # NOTE(review): the sorted list is stored while the creation-order list
        # is returned; callers that index `store` by position rely on the
        # returned order, so the two are intentionally left different here.
        self.pgs = sort_placement_group_by_node_ip(pgs)
        return pgs
@dataclass
class ResourcePoolManager:
    """Build named `RayResourcePool` objects from a spec and look them up."""

    # Pool name -> process counts passed to RayResourcePool as process_on_nodes.
    resource_pool_spec: dict[str, list[int]]
    # Pool name -> pool instance; populated by create_resource_pool().
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        """Instantiate one resource pool for every entry of the spec.

        Each pool is created with accelerators enabled, a colocate count of 1,
        and its own name as the name prefix, then stored in
        ``resource_pool_dict`` under that name (replacing any existing entry).
        """
        for pool_name, counts in self.resource_pool_spec.items():
            self.resource_pool_dict[pool_name] = RayResourcePool(
                process_on_nodes=counts,
                use_gpu=True,
                max_colocate_count=1,
                name_prefix=pool_name,
            )

    def get_resource_pool(self, resource_pool_name: str) -> RayResourcePool:
        """Return the pool registered under ``resource_pool_name``.

        Raises:
            KeyError: If no pool with that name has been created.
        """
        pool = self.resource_pool_dict[resource_pool_name]
        return pool
class RayClassWithInitArgs:
    """Pair an actor class with its constructor arguments and actor options.

    Calling the instance creates the actor inside a specific placement-group
    bundle by invoking ``cls.options(**options).remote(*args, **kwargs)``.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        """Remember the class and the arguments to construct it with.

        Args:
            cls: Object exposing ``.options(**opts).remote(...)``.
            *args: Positional arguments forwarded to ``remote``.
            **kwargs: Keyword arguments forwarded to ``remote``.
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs
        self._options = {}

    def update_options(self, options: dict):
        """Update the Ray actor creation options.

        Args:
            options: Dictionary of options to update; existing keys are
                overwritten.
        """
        self._options.update(options)

    @staticmethod
    def _device_options(use_gpu: bool, num_gpus, device_name: str) -> dict:
        """Return the actor options that reserve the requested accelerator.

        Returns an empty dict when no accelerator is requested, regardless of
        ``device_name``.

        Raises:
            ValueError: If an accelerator is requested for a device other than
                ``"cuda"`` or ``"npu"``.
        """
        if not use_gpu:
            return {}
        if device_name == "cuda":
            return {"num_gpus": num_gpus}
        if device_name == "npu":
            return {"resources": {"NPU": num_gpus}}
        raise ValueError(f"Unsupported device_name: {device_name}")

    def __call__(
        self,
        placement_group,
        placement_group_bundle_idx,
        use_gpu: bool = True,
        num_gpus=1,
        device_name="npu",
    ) -> Any:
        """Create the actor in the given placement-group bundle.

        Args:
            placement_group: Placement group to schedule the actor into.
            placement_group_bundle_idx: Index of the bundle inside that group.
            use_gpu: Whether to reserve an accelerator for the actor.
            num_gpus: Amount of accelerator to reserve (may be fractional).
            device_name: ``"cuda"`` reserves via ``num_gpus``; ``"npu"``
                reserves via the custom ``"NPU"`` resource.

        Returns:
            Whatever ``cls.options(...).remote(...)`` returns.

        Raises:
            ValueError: If ``use_gpu`` is true and ``device_name`` is neither
                ``"cuda"`` nor ``"npu"``.
        """
        options = {
            "scheduling_strategy": PlacementGroupSchedulingStrategy(
                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx
            )
        }
        # Options set through update_options() take precedence over the default
        # scheduling strategy; the device reservation is applied last.
        options.update(self._options)
        # With use_gpu=False nothing is added, so CPU-only actors are created
        # instead of being rejected as an unsupported device.
        options.update(self._device_options(use_gpu, num_gpus, device_name))
        return self.cls.options(**options).remote(*self.args, **self.kwargs)
def create_workers_from_pgs(worker_cls, resource_pool: RayResourcePool, name_prefix: str, device_name: str = "npu"):
    """Create one worker per requested process across the pool's placement groups.

    Placement groups are obtained from ``resource_pool``; for the i-th group,
    ``resource_pool.store[i]`` workers are created, one per bundle, with
    consecutive global ranks. Each worker receives its distributed-setup
    environment variables through the Ray ``runtime_env``.

    Args:
        worker_cls: Object providing ``update_options(dict)`` and a call
            signature accepting ``placement_group``,
            ``placement_group_bundle_idx``, ``use_gpu``, ``num_gpus`` and
            ``device_name`` (presumably a ``RayClassWithInitArgs``).
        resource_pool: Pool describing how many workers go in each group.
        name_prefix: Prefix used for actor names and the ``WG_PREFIX`` variable.
        device_name: Accelerator kind forwarded to the pool and to each worker.

    Returns:
        The created workers, ordered by global rank.

    Raises:
        ValueError: If no placement group was created, or a group has fewer
            bundles than the workers requested for it.
        RuntimeError: If the number of created workers differs from the pool's
            world size.
    """
    pgs = resource_pool.get_placement_groups(device_name=device_name)
    if not pgs:
        raise ValueError("No placement group created for rollout workers")

    # Resolve the rendezvous address from inside bundle 0 of the first group,
    # the same bundle the first worker of that group is placed in.
    master_addr, master_port = ray.get(
        get_master_addr_port.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pgs[0],
                placement_group_bundle_index=0,
            ),
        ).remote()
    )
    # runtime_env env_vars must be strings; coerce here so a numeric port does
    # not get rejected. This is a no-op when the values are already strings.
    master_addr = str(master_addr)
    master_port = str(master_port)

    world_size = resource_pool.world_size
    workers = []
    rank = 0  # global rank of the next worker to create
    for pg_idx, pg in enumerate(pgs):
        local_world_size = resource_pool.store[pg_idx]
        if local_world_size > pg.bundle_count:
            raise ValueError(
                f"Requested {local_world_size} workers in pg[{pg_idx}] "
                f"but only {pg.bundle_count} bundles exist."
            )
        for local_rank in range(local_world_size):
            env_vars = {
                "WORLD_SIZE": str(world_size),
                "RANK": str(rank),
                "LOCAL_WORLD_SIZE": str(local_world_size),
                "LOCAL_RANK": str(local_rank),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
                "WG_PREFIX": name_prefix,
                "WG_BACKEND": "ray",
                "RAY_LOCAL_WORLD_SIZE": str(local_world_size),
            }
            # Options are replaced before every creation so each worker gets
            # its own environment and actor name.
            worker_cls.update_options(
                {
                    "runtime_env": {"env_vars": env_vars},
                    "name": f"{name_prefix}_{pg_idx}:{local_rank}",
                }
            )
            worker = worker_cls(
                placement_group=pg,
                placement_group_bundle_idx=local_rank,
                use_gpu=resource_pool.use_gpu,
                # Fraction of one accelerator per worker.
                num_gpus=1 / resource_pool.max_colocate_count,
                device_name=device_name,
            )
            workers.append(worker)
            rank += 1

    if len(workers) != world_size:
        raise RuntimeError(f"Worker count mismatch: {len(workers)} != {world_size}")
    return workers