import asyncio
import os
import time
import tqdm
import torch
import torch.distributed as dist
import requests
import multiprocessing
import multiprocessing.connection
from typing import Optional, Callable
from torch.distributed.tensor import DTensor
import json
from configs import AgentTrainingConfig
from log import logger
from databases import InferenceService
from .utils import init_process_group, obtain_local_port, _compute_tensor_nbytes
from sglang.srt.utils import MultiprocessingSerializer
from sglang.srt.entrypoints.http_server import launch_server, ServerArgs
try:
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
except (ImportError, ModuleNotFoundError):
from sglang.srt.patch_torch import monkey_patch_torch_reductions
def _compute_tensor_nbytes(t: torch.Tensor) -> int:
try:
return t.numel() * t.element_size()
except Exception:
return int(t.nelement()) * int(getattr(t, "element_size", lambda: 1)())
def fixed_launch_server_process(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
    launch_callback: Optional[Callable[[], None]] = None,
):
    """Patch ``TokenizerManager`` logprob detokenization, then launch the server.

    ``TokenizerManager.detokenize_logprob_tokens`` is replaced with a version
    that, when text is requested, maps token ids to token strings through
    ``tokenizer.convert_ids_to_tokens``.  After the patch is installed the
    call is forwarded unchanged to ``launch_server``.

    Args:
        server_args: Server configuration passed through to ``launch_server``.
        pipe_finish_writer: Optional pipe end passed through to ``launch_server``.
        launch_callback: Optional callback passed through to ``launch_server``.
    """
    from typing import List

    from sglang.srt.managers.tokenizer_manager import TokenizerManager

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # Produces (logprob, token_id, token_text) triples; the text slot is
        # None when no decoding was requested.
        if decode_to_text:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.convert_ids_to_tokens(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
        return [
            (value, index, None)
            for value, index in zip(token_logprobs_val, token_logprobs_idx)
        ]

    # Install the patch on the class so every TokenizerManager instance
    # created by the server picks it up.
    TokenizerManager.detokenize_logprob_tokens = detokenize_logprob_tokens

    launch_server(
        server_args=server_args,
        pipe_finish_writer=pipe_finish_writer,
        launch_callback=launch_callback,
    )
class SingletonMeta(type):
    """Metaclass that hands out one shared instance per class.

    The first call to a class using this metaclass constructs the instance
    with the given arguments; every later call returns that same object and
    ignores its arguments.
    """

    # Maps each class to its single instance; shared by all classes that use
    # this metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
class InferenceManager(metaclass=SingletonMeta):
def __init__(
    self,
    config: AgentTrainingConfig,
    loop: Optional[asyncio.AbstractEventLoop] = None,
):
    """
    Build the device mesh and bring up this rank's share of the inference service.

    The default process group is arranged as a 2-D CUDA mesh with dimensions
    ``("dp", "tp")`` and shape ``(world_size // inf_tp_size, inf_tp_size)``.
    The rank whose tp-rank is 0 launches an SGLang server through
    ``_run_sglang_process``; afterwards every rank calls
    ``release_memory_occupation`` with a 600 second timeout.

    Args:
        config (AgentTrainingConfig): Training configuration. This method
            reads ``inf_tp_size`` and ``model_name_or_path`` from it.
        loop (Optional[asyncio.AbstractEventLoop]): Event loop used to run
            the startup coroutines. When omitted, the current loop is used,
            and a new one is created and installed if none exists.

    Raises:
        RuntimeError: If ``LOCAL_RANK`` or ``LOCAL_WORLD_SIZE`` is missing
            from the environment.
    """
    self.config = config

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python versions no longer create a loop implicitly.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    self.loop = loop

    # Fail fast with a readable message instead of int(None) -> TypeError.
    missing = [
        name for name in ("LOCAL_RANK", "LOCAL_WORLD_SIZE")
        if name not in os.environ
    ]
    if missing:
        raise RuntimeError(
            "InferenceManager requires the environment variable(s) "
            f"{', '.join(missing)} to be set."
        )

    # NOTE(review): assumes the default process group is already initialized.
    self.rank = dist.get_rank()
    self.world_size = dist.get_world_size()
    self.mesh = dist.device_mesh.init_device_mesh(
        device_type="cuda",
        mesh_shape=(self.world_size // self.config.inf_tp_size, self.config.inf_tp_size),
        mesh_dim_names=["dp", "tp"],
    )
    self.tp_rank = self.mesh.get_local_rank("tp")
    self.dp_rank = self.mesh.get_local_rank("dp")
    self.local_rank = int(os.environ["LOCAL_RANK"])
    self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    self.service_id = None
    self.weights_update_group = None

    tp_group = self.mesh.get_group("tp")
    dist.barrier(tp_group)
    if self.tp_rank == 0:
        # Only the first rank of each tp group owns a server process.
        self.loop.run_until_complete(
            self._run_sglang_process(model_name_or_path=self.config.model_name_or_path)
        )
    dist.barrier(tp_group)
    # Every rank must take part: release_memory_occupation contains tp-group barriers.
    self.loop.run_until_complete(self.release_memory_occupation(timeout=600.0))
def is_running(self):
"""Check if the inference service is running."""
return hasattr(self, "inference_proc") and self.inference_proc is not None and self.inference_proc.is_alive()
async def _run_sglang_process(self, model_name_or_path):
    """Start an SGLang server process, wait until healthy, and register it.

    Steps:
      1. When the tp mesh dimension has size <= 1, temporarily rewrite
         ``CUDA_VISIBLE_DEVICES`` so the child process is started with only
         this rank's device visible (``base_gpu_id`` is then 0).
      2. Start ``fixed_launch_server_process`` in a new process and poll
         ``/health_generate`` every 2 seconds for up to 300 seconds.
      3. Store the process on ``self.inference_proc``, save an
         ``InferenceService`` record with status "UP", and remember its id
         in ``self.service_id``.

    ``CUDA_VISIBLE_DEVICES`` is put back to its previous state (including
    "unset") on both success and failure.

    Args:
        model_name_or_path: Model path given to the server and recorded on
            the service entry.

    Raises:
        Exception: If the server process exits while being waited on.
        TimeoutError: If the server is not healthy within the timeout; the
            process is terminated first.
    """
    os.environ.pop("TORCHELASTIC_USE_AGENT_STORE", None)
    monkey_patch_torch_reductions()
    torch.multiprocessing.set_start_method("spawn", force=True)
    local_ip, port = obtain_local_port()

    tp_size = self.mesh["tp"].size()
    pin_devices = tp_size <= 1
    # None means "was not set"; this must be told apart from an empty string
    # so the variable can be removed again instead of being left as "".
    previous_visible = os.environ.get("CUDA_VISIBLE_DEVICES")

    try:
        if pin_devices:
            base_local_rank = max(0, self.local_rank - self.tp_rank)
            wanted = range(base_local_rank, base_local_rank + tp_size)
            if previous_visible:
                try:
                    devices = [
                        int(x.strip())
                        for x in previous_visible.split(",")
                        if x.strip() != ""
                    ]
                except ValueError:
                    # Non-integer entries: fall back to plain ordinals.
                    devices = list(range(self.local_world_size))
                selected = [
                    str(devices[idx]) if idx < len(devices) else str(idx)
                    for idx in wanted
                ]
            else:
                selected = [str(idx) for idx in wanted]
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(selected)

        server_args = ServerArgs(
            model_path=model_name_or_path,
            host=local_ip,
            port=port,
            tp_size=self.config.inf_tp_size,
            ep_size=self.config.inf_ep_size,
            dp_size=1,
            enable_dp_attention=True,
            mem_fraction_static=self.config.inf_mem_ratio,
            trust_remote_code=True,
            dtype=self.config.torch_dtype,
            served_model_name=self.config.model_name,
            enable_memory_saver=True,
            enable_fp32_lm_head=True,
            load_format="dummy",
            log_level="warning",
            base_gpu_id=0 if pin_devices else self.dp_rank * self.config.inf_tp_size % self.local_world_size,
            tool_call_parser=self.config.tool_call_parser,
            mm_attention_backend="fa3",
            attention_backend="fa3",
            gpu_id_step=1,
            random_seed=self.rank,
            preferred_sampling_params=self.config.preferred_sampling_params,
        )

        # The child inherits the environment as it is at start() time.
        p = multiprocessing.Process(target=fixed_launch_server_process, args=(server_args,))
        p.start()

        base_url = server_args.url()
        timeout = 300.0
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {server_args.api_key}",
        }
        start_time = time.perf_counter()
        with requests.Session() as session:
            while time.perf_counter() - start_time < timeout:
                try:
                    response = session.get(f"{base_url}/health_generate", headers=headers)
                    if response.status_code == 200:
                        break
                except requests.RequestException:
                    # Server not accepting connections yet; keep polling.
                    pass
                if not p.is_alive():
                    raise Exception("Server process terminated unexpectedly.")
                # NOTE(review): blocking sleep inside a coroutine, kept as in
                # the original; the HTTP poll above is blocking as well.
                time.sleep(2)
            else:
                # Loop ended without a healthy response: do not register a
                # server that never came up.
                p.terminate()
                raise TimeoutError("Server failed to start within the timeout period.")

        self.inference_proc = p
        await asyncio.sleep(3)
        logger.debug(f"Inference service started at {base_url}")
    finally:
        if pin_devices:
            if previous_visible is None:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = previous_visible

    service = InferenceService(
        models=[model_name_or_path],
        connection_type="openai",
        configs=InferenceService.OpenAIConfig(
            base_url=f"http://{local_ip}:{port}/v1",
            host=local_ip,
            port=port,
        ),
        status="UP",
    )
    await service.save()
    self.service_id = service.id
def get_load(self, service: InferenceService):
    """Query the service's ``/get_load`` endpoint and reduce it to one number.

    Accepted response shapes:
      * a dict with a ``"load"`` key: that value is returned as-is;
      * a dict with ``"num_reqs"``: the sum of ``num_reqs``,
        ``num_waiting_reqs`` and ``num_tokens`` is returned, with the latter
        two defaulting to 0 when absent;
      * a list of such per-entry dicts: the per-entry sums are added up,
        with the same defaults as the dict form.

    Args:
        service: Service record whose ``configs.host`` / ``configs.port``
            identify the server to query.

    Returns:
        The load value reported by, or derived from, the response.

    Raises:
        requests.RequestException: On connection failure or non-2xx status.
        ValueError: If the response has none of the shapes above.
    """
    response = requests.get(
        f"http://{service.configs.host}:{service.configs.port}/get_load",
        json={},
    )
    response.raise_for_status()
    load_info = response.json()

    def _entry_load(entry):
        # Only num_reqs is mandatory; the other counters are optional.
        return (
            entry["num_reqs"]
            + entry.get("num_waiting_reqs", 0)
            + entry.get("num_tokens", 0)
        )

    if isinstance(load_info, dict):
        if "load" in load_info:
            return load_info["load"]
        if "num_reqs" in load_info:
            return _entry_load(load_info)
    elif isinstance(load_info, list):
        return sum(_entry_load(entry) for entry in load_info)
    raise ValueError("Invalid load info format received from inference service.")
def _pause(self, service: InferenceService):
    """POST to the service's ``/pause_generation`` endpoint.

    Does nothing when ``service`` is None.  A non-2xx response raises via
    ``raise_for_status``.
    """
    if service is not None:
        endpoint = f"http://{service.configs.host}:{service.configs.port}/pause_generation"
        reply = requests.post(endpoint, json={})
        reply.raise_for_status()
        logger.debug("Inference service paused successfully.")
def _continue(self, service: InferenceService):
    """POST to the service's ``/continue_generation`` endpoint.

    Does nothing when ``service`` is None.  A non-2xx response raises via
    ``raise_for_status``.
    """
    if service is not None:
        endpoint = f"http://{service.configs.host}:{service.configs.port}/continue_generation"
        reply = requests.post(endpoint, json={})
        reply.raise_for_status()
        logger.debug("Inference service continued successfully.")
async def release_memory_occupation(self, timeout: float = 90.0):
    """Drain the inference service and ask it to release its memory.

    On a rank that owns a service (``self.service_id`` is set):
      1. The service record is marked "DOWN".
      2. The load is probed until it has been idle for
         ``SG_SRT_MAX_PROBE_TIMES`` (default 5) consecutive probes spaced
         2 seconds apart.  While busy, the record is re-synced every
         second; once ``timeout`` seconds have passed, generation is paused
         and the timer restarts.
      3. ``/release_memory_occupation`` is posted; a request failure here is
         only logged.
      4. If the server process is no longer running, it is terminated, its
         record deleted, a fresh server started, and the release request is
         sent to the *new* server.

    Every rank, with or without a service, joins two barriers on the tp
    group, so all ranks of a tp group must call this together.

    Args:
        timeout: Seconds to wait for in-flight work before pausing generation.

    Raises:
        RuntimeError: If the restarted server fails to release its memory.
    """
    tp_group = self.mesh.get_group("tp")

    if self.service_id is not None:
        service = await InferenceService.get(self.service_id)
        service = await service.set({InferenceService.status: "DOWN"})

        start_time = time.perf_counter()
        max_probe_times = int(os.environ.get("SG_SRT_MAX_PROBE_TIMES", 5))
        probes_left = max_probe_times
        while probes_left > 0:
            probes_left -= 1
            # NOTE(review): get_load() raises if the server is unreachable,
            # which bypasses the restart path below -- confirm intent.
            while (load := self.get_load(service)) > 0 or service.running_req_count > 0:
                if time.perf_counter() - start_time < timeout:
                    await asyncio.sleep(1)
                else:
                    logger.warning("Releasing memory after waiting for {:.2f} seconds (load {}).".format(time.perf_counter() - start_time, load))
                    self._pause(service)
                    start_time = time.perf_counter()
                await service.sync()
                # Any sign of activity restarts the idle-probe countdown.
                probes_left = max_probe_times
            await asyncio.sleep(2)

        try:
            response = requests.post(
                f"http://{service.configs.host}:{service.configs.port}/release_memory_occupation",
                json={},
            )
            response.raise_for_status()
            logger.debug("Memory occupation released successfully.")
        except requests.RequestException as e:
            logger.error(f"Failed to release memory occupation: {e}")

    dist.barrier(tp_group)

    if self.service_id is not None and not self.is_running():
        logger.warning(f"Inference service {self.service_id} is not running.")
        self.terminate_inference_service()
        await service.delete()
        await self._run_sglang_process(self.config.model_name_or_path)
        # _run_sglang_process registered a new record (new host/port) under
        # self.service_id; the old `service` object points at the dead server.
        service = await InferenceService.get(self.service_id)
        # Mirror the main path: the service is not usable while its memory
        # is released, so mark it "DOWN" until resume_memory_occupation.
        service = await service.set({InferenceService.status: "DOWN"})
        try:
            response = requests.post(
                f"http://{service.configs.host}:{service.configs.port}/release_memory_occupation",
                json={},
            )
            response.raise_for_status()
            logger.debug("Memory occupation released successfully.")
        except requests.RequestException as e:
            logger.error(f"Failed to release memory occupation while restarting inference: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError("Inference Service failed (and failed to restart) while releasing memory occupation") from e

    dist.barrier(tp_group)
async def resume_memory_occupation(self):
    """Ask the inference service to re-acquire its memory and mark it "UP".

    On a rank that owns a service, ``/resume_memory_occupation`` is posted
    and, on success, the service record's status is set to "UP" and saved.
    A request failure is logged and otherwise ignored.  Every rank then
    joins a barrier on the tp group.
    """
    if self.service_id is None:
        # No service on this rank; only take part in the synchronization.
        dist.barrier(self.mesh.get_group("tp"))
        return

    logger.debug(f"Resuming memory occupation for service {self.service_id}...")
    try:
        service = await InferenceService.get(self.service_id)
        endpoint = f"http://{service.configs.host}:{service.configs.port}/resume_memory_occupation"
        reply = requests.post(endpoint, json={})
        reply.raise_for_status()
        logger.debug("Memory occupation resumed successfully.")
        service.status = "UP"
        await service.save()
    except requests.RequestException as e:
        logger.error(f"Failed to resume memory occupation: {e}")
    dist.barrier(self.mesh.get_group("tp"))
async def update_model_weights(self, state_dict: dict[str, torch.Tensor], device: Optional[torch.device], *, budget_bytes: Optional[int] = 512 * 1024 * 1024):
    """Broadcast weights from rank 0 and push them to the inference service.

    For each entry of ``state_dict`` (in iteration order), rank 0 moves the
    tensor to ``device`` while every other rank allocates an empty tensor of
    the same shape and dtype; the tensor is then broadcast from rank 0 over
    the default process group.  Tensors are collected into a batch; once the
    batch exceeds ``budget_bytes`` (and once more at the end) the batch is
    serialized, gathered onto tp-rank 0 of the tp group, and posted by that
    rank to ``/update_weights_from_tensor``.  Finally generation is
    continued on the service and all ranks meet at a global barrier.

    NOTE(review): assumes every rank passes a ``state_dict`` with the same
    keys, order, shapes and dtypes, since each step is a collective -- confirm.

    Args:
        state_dict: Parameter name to tensor mapping.
        device: Device the tensors are placed on before broadcasting.
        budget_bytes: Batch size threshold in bytes.  ``None`` is treated as
            0, i.e. the batch is flushed after every non-empty tensor.

    Raises:
        requests.RequestException: If the weight-update request fails.
    """
    if self.service_id is not None:
        service = await InferenceService.get(self.service_id)
    else:
        service = None

    # Normalize once so the loop and the flush helper agree on None.
    budget = budget_bytes or 0
    tp_group = self.mesh.get_group("tp")

    pbar = tqdm.tqdm(total=len(state_dict), desc="Updating weights", unit="tensors", disable=self.rank != 0)
    named_tensors_batch = []
    current_batch_bytes = 0

    def flush_batch_if_needed(force: bool = False):
        nonlocal named_tensors_batch, current_batch_bytes
        serialized_tensors = MultiprocessingSerializer.serialize(named_tensors_batch, output_str=True)
        if self.tp_rank == 0:
            gathered_serialized_tensors = [None for _ in range(tp_group.size())]
        else:
            gathered_serialized_tensors = None
        # Collective: every rank of the tp group must reach this call, even
        # with an empty batch.
        dist.gather_object(
            obj=serialized_tensors,
            object_gather_list=gathered_serialized_tensors,
            group=tp_group,
            group_dst=0,
        )
        is_sender = bool(service) and self.tp_rank == 0
        if is_sender and named_tensors_batch and (force or current_batch_bytes > budget):
            pbar.set_description(f"Updating {len(named_tensors_batch)} weights with {current_batch_bytes / (1024*1024):.2f} MB")
            response = requests.post(
                f"http://{service.configs.host}:{service.configs.port}/update_weights_from_tensor",
                json={
                    "serialized_named_tensors": gathered_serialized_tensors,
                    "load_format": None,
                    "flush_cache": False,
                },
            )
            response.raise_for_status()
            pbar.update(len(named_tensors_batch))
        named_tensors_batch = []
        current_batch_bytes = 0

    try:
        for name, tensor in state_dict.items():
            if self.rank == 0:
                tensor = tensor.to(device).detach()
            else:
                # Receive buffer for the broadcast below.
                tensor = torch.empty_like(tensor, dtype=tensor.dtype, device=device)
            dist.broadcast(tensor, src=0)
            named_tensors_batch.append((name, tensor))
            current_batch_bytes += _compute_tensor_nbytes(tensor)
            if current_batch_bytes > budget:
                flush_batch_if_needed(force=True)
        # Final flush for whatever is left (also keeps the collective count
        # identical on every rank).
        flush_batch_if_needed(force=True)
    finally:
        pbar.close()

    self._continue(service)
    dist.barrier()
def terminate_inference_service(self, *args, **kwargs):
    """Stop the server process, escalating from terminate to kill.

    Does nothing when no process handle is stored.  Otherwise a live
    process is asked to terminate, given 10 seconds to exit, and killed if
    it is still alive after that.  The handle is then closed and
    ``self.inference_proc`` reset to None.

    Args:
        *args: Ignored.
        **kwargs: Ignored.
    """
    proc = getattr(self, "inference_proc", None)
    if not proc:
        return

    if proc.is_alive():
        proc.terminate()
        # join() always returns None, so liveness has to be re-checked to
        # know whether the grace period expired.
        proc.join(timeout=10)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=10)

    if proc.is_alive():
        # close() raises ValueError on a running process; keep the handle
        # unclosed rather than crash during shutdown.
        logger.warning("Inference service process is still alive after kill; not closing its handle.")
    else:
        proc.close()
    self.inference_proc = None
    logger.debug("Inference service process terminated successfully.")
def clean(self):
    """Shut the inference service down.

    Thin wrapper that delegates to ``terminate_inference_service`` with no
    arguments.
    """
    self.terminate_inference_service()