import os
import signal
import ipaddress
import subprocess
import threading
from motor.common.resources.instance import PDRole
from motor.common.resources.endpoint import Endpoint
from motor.common.utils.singleton import ThreadSafeSingleton
from motor.common.logger import get_logger
from motor.common.utils.env import Env
from motor.config.node_manager import NodeManagerConfig
from motor.common.utils.snapshot_utils import MOTOR_SNAPSHOT_METADATA_PATH
# Logger for this module, obtained from the project logging helper.
logger = get_logger(__name__)
# Inclusive bounds that Daemon._check_params accepts for an endpoint's
# business port.
MAX_PORT = 65535
MIN_PORT = 1024
class Daemon(ThreadSafeSingleton):
    """Start and kill the ``engine_server`` processes of this node.

    The PIDs of the processes started by :meth:`pull_engine` are recorded in
    ``engine_pids`` (guarded by ``_pids_lock``) so that :meth:`stop` can kill
    them later.
    """

    def __init__(self, config: NodeManagerConfig | None = None):
        """Read the launch settings from *config*.

        Args:
            config: Node manager configuration. When ``None`` it is loaded
                with ``NodeManagerConfig.from_json()``.
        """
        # NOTE(review): presumably the singleton base hands back the same
        # instance on every call, so __init__ may run again; never reset the
        # state of an instance that is already set up.
        if hasattr(self, "_initialized"):
            return
        self.engine_pids: list[int] = []
        self._pids_lock = threading.Lock()
        if config is None:
            config = NodeManagerConfig.from_json()
        self.parallel_config = config.basic_config.parallel_config
        self.device_num = config.basic_config.device_num
        self.single_container_flag = config.single_container_config.single_container_flag
        self.enable_multi_endpoints = config.basic_config.enable_multi_endpoints
        self.enable_snapshot = config.snapshot_config.enable_snapshot
        # An empty configured path falls back to the project default.
        self.snapshot_metadata_path = (
            config.snapshot_config.snapshot_metadata_path
            if config.snapshot_config.snapshot_metadata_path != ""
            else MOTOR_SNAPSHOT_METADATA_PATH
        )
        if self.single_container_flag:
            self.device_offset = config.single_container_config.device_offset
            self.kv_port = config.single_container_config.kv_port
            self.lookup_rpc_port = config.single_container_config.lookup_rpc_port
            self.dp_rpc_port = config.single_container_config.dp_rpc_port
        # Set last: the flag marks an instance whose attributes (including
        # the lock) all exist.
        self._initialized = True

    @staticmethod
    def _check_params(params: Endpoint) -> bool:
        """Return True when the endpoint's business port and IP are usable.

        The port must convert to an int within ``[MIN_PORT, MAX_PORT]`` and
        the IP must parse with ``ipaddress.ip_address``. Every rejection is
        logged at error level.
        """
        try:
            port = int(params.business_port)
        except (TypeError, ValueError):
            # TypeError covers a port that is None or another non-numeric
            # object; ValueError covers a non-numeric string.
            logger.error("Invalid port value: %s", params.business_port)
            return False
        if not MIN_PORT <= port <= MAX_PORT:
            logger.error("Port %s is out of valid range", port)
            return False
        try:
            ipaddress.ip_address(params.ip)
        except ValueError:
            logger.error("Invalid IP address: %s", params.ip)
            return False
        except Exception as e:
            logger.error("Error validating IP address %s: %s", params.ip, e)
            return False
        return True

    @staticmethod
    def _to_engine_role(pd_role_info: PDRole) -> str:
        """Map a ``PDRole`` to the string passed to ``--role``.

        ``PDRole.ROLE_U`` becomes ``"union"``; every other role is passed as
        the string form of its enum value.
        """
        if pd_role_info == PDRole.ROLE_U:
            return "union"
        return str(pd_role_info.value)

    @staticmethod
    def _select_peer_ips(d2d_peer_ips: list[str], endpoint_id: str) -> list[str]:
        """Return the IPs of the ``"<endpoint id>:<ip>"`` entries for *endpoint_id*.

        Only the first ``:`` separates the id from the IP, so the IP part may
        itself contain colons.

        Raises:
            ValueError: if an entry contains no ``:`` separator.
        """
        selected = []
        for entry in d2d_peer_ips:
            encoded_ep_id, sep, ip = entry.partition(":")
            if not sep:
                raise ValueError("Malformed D2D peer entry (expected '<ep_id>:<ip>'): %s" % entry)
            if encoded_ep_id == endpoint_id:
                selected.append(ip)
        return selected

    def _build_engine_cmd(
        self,
        endpoint: Endpoint,
        pd_role_info: PDRole,
        instance_id: int,
        master_dp_ip: str,
        node_rank: int,
        d2d_peer_ips: list[str] | None,
    ) -> list[str]:
        """Build the ``engine_server`` argument list for one endpoint."""
        cmd = [
            "engine_server",
            "--dp-rank",
            str(endpoint.id),
            "--instance-id",
            str(instance_id),
            "--role",
            self._to_engine_role(pd_role_info),
            "--host",
            str(endpoint.ip),
            "--port",
            str(int(endpoint.business_port)),
            "--mgmt-port",
            str(int(endpoint.mgmt_port)),
            "--master-dp-ip",
            master_dp_ip,
            "--node-rank",
            str(node_rank),
            "--config-path",
            str(Env.user_config_path),
        ]
        if self.enable_snapshot:
            cmd.extend(["--snapshot-metadata", self.snapshot_metadata_path])
        if self.single_container_flag:
            cmd.extend(["--kv-port", str(self.kv_port)])
            cmd.extend(["--dp-rpc-port", str(self.dp_rpc_port)])
            if self.lookup_rpc_port is not None:
                cmd.extend(["--lookup-rpc-port", str(self.lookup_rpc_port)])
        if d2d_peer_ips:
            peer_ips = self._select_peer_ips(d2d_peer_ips, str(endpoint.id))
            if peer_ips:
                cmd.extend(["--d2d-peer-ips", ",".join(peer_ips)])
                logger.info("D2D peer IPs for ep_id %s: %s", endpoint.id, peer_ips)
        return cmd

    def pull_engine(
        self,
        pd_role_info: PDRole,
        endpoints_info: list[Endpoint],
        instance_id: int,
        master_dp_ip: str,
        d2d_peer_ips: list[str] | None = None,
        node_rank: int = 0,
    ):
        """
        start engine processes based on the provided role and endpoint information.

        All endpoints are validated and all command lines are built before
        the first process is spawned, so invalid input never leaves a partial
        set of engines running.

        engine_server parameters:
        --dp-rank             engine dpGroup rank (the endpoint id)
        --instance-id         instance id
        --role                prefill | decode | union
        --host                engine service ip
        --port                engine service port
        --mgmt-port           endpoint management port
        --master-dp-ip        master data parallel node IP address
        --node-rank           node rank assigned by Controller (registration order)
        --config-path         engine config file path
        --snapshot-metadata   snapshot metadata path (only when snapshot is enabled)
        --kv-port             single-container mode only
        --dp-rpc-port         single-container mode only
        --lookup-rpc-port     single-container mode only, when configured
        --d2d-peer-ips        comma-separated peer IPs selected for the endpoint

        Args:
            pd_role_info: role of the engines to start.
            endpoints_info: one engine process is started per endpoint.
            instance_id: value passed as ``--instance-id``.
            master_dp_ip: value passed as ``--master-dp-ip``.
            d2d_peer_ips: optional ``"<endpoint id>:<ip>"`` entries; each
                engine receives the IPs whose id equals its endpoint id.
            node_rank: value passed as ``--node-rank``.

        Raises:
            RuntimeError: on any failure; the original exception is chained.
                Processes spawned before a spawn-time failure stay recorded
                in ``engine_pids`` and can be killed with :meth:`stop`.
        """
        try:
            env = os.environ.copy()
            pod_ip = env.get("POD_IP")
            if pod_ip and not env.get("VLLM_HOST_IP"):
                env["VLLM_HOST_IP"] = pod_ip
            if env.get("MOONCAKE_ASCEND_IPV6_EXPERIMENT") == "1":
                # Keep an explicit MC_USE_IPV6 value if one is already set.
                env.setdefault("MC_USE_IPV6", "1")

            # Phase 1: validate and build everything; nothing is spawned yet.
            launch_plan: list[tuple[list[str], str | None]] = []
            for index, endpoint in enumerate(endpoints_info):
                if not self._check_params(endpoint):
                    raise ValueError("Invalid endpoint parameters")
                device_ids_str = None
                if self.enable_multi_endpoints:
                    device_ids_str = self._calc_visible_device_ids(index, self.device_num)
                cmd = self._build_engine_cmd(
                    endpoint, pd_role_info, instance_id, master_dp_ip, node_rank, d2d_peer_ips
                )
                launch_plan.append((cmd, device_ids_str))

            # Phase 2: spawn one process per endpoint.
            for cmd, device_ids_str in launch_plan:
                if device_ids_str is not None:
                    logger.info("Device IDs: %s", device_ids_str)
                    # The child receives the environment as it is at spawn
                    # time, so the shared dict can be updated per endpoint.
                    env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids_str
                logger.info(" ".join(cmd))
                process = subprocess.Popen(cmd, shell=False, env=env)
                if process.poll() is not None:
                    raise RuntimeError("Engine process exited immediately with code %s" % process.returncode)
                with self._pids_lock:
                    self.engine_pids.append(process.pid)
        except Exception as e:
            raise RuntimeError("Failed to pull engine: %s" % e) from e

    def stop(self):
        """Send SIGKILL to every recorded engine process and forget the PIDs.

        The PID list is copied and cleared under the lock; the signals are
        sent afterwards without holding it. Failures are logged and never
        raised, so one bad PID does not prevent killing the others.
        """
        with self._pids_lock:
            pids = list(self.engine_pids)
            self.engine_pids.clear()
        for pid in pids:
            try:
                os.kill(pid, signal.SIGKILL)
                logger.info("Killed engine process with PID: %s", pid)
            except ProcessLookupError:
                logger.info("Process %s already terminated", pid)
            except PermissionError:
                logger.error("No permission to kill process %s", pid)
            except Exception as e:
                logger.error("Failed to kill process %s: %s", pid, e)

    def _calc_visible_device_ids(self, index: int, device_size: int) -> str:
        """Calculate visible device IDs string for ASCEND_RT_VISIBLE_DEVICES.

        Endpoint *index* gets ``local_world_size`` consecutive device IDs
        starting at ``index * local_world_size``, wrapping around modulo
        *device_size*. In single-container mode ``device_offset`` is added to
        every ID.

        Args:
            index: position of the endpoint in the list being launched.
            device_size: number of devices the IDs wrap around.

        Returns:
            Comma-separated device IDs string, e.g., "0,1,2,3"

        Raises:
            ValueError: if *device_size* is not positive, or if
                ``local_world_size`` exceeds it (the result would contain
                duplicate or out-of-range IDs).
        """
        local_world_size = self.parallel_config.local_world_size
        if device_size <= 0:
            raise ValueError("device_size must be positive, got %s" % device_size)
        if local_world_size > device_size:
            raise ValueError(
                "local_world_size %s exceeds device_size %s" % (local_world_size, device_size)
            )
        start_device_id = index * local_world_size % device_size
        device_ids = [(start_device_id + k) % device_size for k in range(local_world_size)]
        if self.single_container_flag:
            device_ids = [x + self.device_offset for x in device_ids]
        return ",".join(map(str, device_ids))