import abc
import logging
import os
import random
from datetime import timedelta
import ray
import torch
import torch.distributed as dist
import slime.utils.eval_config
from slime.ray.ray_actor import RayActor
from slime.utils.distributed_utils import init_gloo_group
from slime.utils.logging_utils import configure_logger
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.common import is_npu
logger = logging.getLogger(__name__)
def get_local_gpu_id():
if is_npu():
env_var = "ASCEND_RT_VISIBLE_DEVICES"
device_ids = ray.get_runtime_context().get_accelerator_ids()["NPU"]
else:
env_var = "CUDA_VISIBLE_DEVICES"
device_ids = ray.get_gpu_ids()
cvd = os.environ.get(env_var, None)
if cvd is None:
return device_ids[0]
else:
return cvd.split(",").index(str(device_ids[0]))
class TrainRayActor(RayActor):
def __init__(self, world_size, rank, master_addr, master_port):
configure_logger()
self._world_size = world_size
self._rank = rank
if master_addr:
self.master_addr, self.master_port = master_addr, master_port
else:
self.master_addr, self.master_port = self._get_current_node_ip_and_free_port(
start_port=random.randint(20000, 21000)
)
os.environ["MASTER_ADDR"] = self.master_addr
os.environ["MASTER_PORT"] = str(self.master_port)
os.environ["WORLD_SIZE"] = str(self._world_size)
os.environ["RANK"] = str(self._rank)
os.environ["LOCAL_RANK"] = str(get_local_gpu_id())
def init(self, args, role, with_ref=False, with_opd_teacher=False):
self.args = args
self.role = role
self.with_ref = with_ref
self.with_opd_teacher = with_opd_teacher
torch.serialization.add_safe_globals([slime.utils.eval_config.EvalDatasetConfig])
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(f"cuda:{local_rank}")
backend = args.distributed_backend
dist.init_process_group(
backend=backend,
timeout=timedelta(minutes=args.distributed_timeout_minutes),
)
init_gloo_group()
args.rank = dist.get_rank()
args.world_size = dist.get_world_size()
try:
if torch.version.hip is not None:
logger.info("Detected ROCm/HIP environment, skipping NUMA affinity setup")
else:
import pynvml
pynvml.nvmlInit()
local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
pynvml.nvmlDeviceSetCpuAffinity(handle)
logger.info(f"Set NUMA affinity for GPU {local_rank}")
pynvml.nvmlShutdown()
except ImportError:
logger.info("Warning: pynvml not available, skipping NUMA affinity setup")
except Exception as e:
logger.info(f"Warning: Failed to set NUMA affinity: {e}")
def clear_memory(self):
if self.args.debug_rollout_only:
return
print_memory("before TrainRayActor.clear_memory")
clear_memory()
print_memory("after TrainRayActor.clear_memory")
@abc.abstractmethod
def sleep(self, tags):
raise NotImplementedError
@abc.abstractmethod
def wake_up(self, tags):
raise NotImplementedError
@abc.abstractmethod
def train(self, rollout_id, rollout_data_ref):
raise NotImplementedError
@abc.abstractmethod
def save_model(self, rollout_id, force_sync=False):
raise NotImplementedError
@abc.abstractmethod
def update_weights(self):
raise NotImplementedError
@abc.abstractmethod
def connect_actor_critic(self, critic_group):
raise NotImplementedError
@abc.abstractmethod
def _get_parallel_config(self):
raise NotImplementedError
def set_rollout_manager(self, rollout_manager):
self.rollout_manager = rollout_manager
if not self.args.debug_rollout_only and self.args.rank == 0:
ray.get(self.rollout_manager.set_train_parallel_config.remote(self.train_parallel_config))