@@ -1,3 +1,3 @@
[codespell]
-ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS
+ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, medias
skip = *.json,*.jsonl,*.patch,*.txt
@@ -274,6 +274,7 @@ class ModelConfig:
if is_draft_model and self.hf_config.architectures[0] in [
"DeepseekV3ForCausalLM",
+ "DeepseekV32ForCausalLM",
"GlmMoeDsaForCausalLM",
]:
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
@@ -1016,10 +1017,10 @@ class ModelConfig:
if tf_version < required_version:
if needs_tf_v5:
- raise ValueError(
- f"Transformers version {tf_version_str} is not supported for model {self.model_path} "
+ logger.warning(
+ f"Transformers version {tf_version_str} may not be fully supported for model {self.model_path} "
f"or model type {self.hf_config.model_type}. "
- "Please upgrade transformers to >= 5.0.0."
+ "Recommended transformers >= 5.0.0, but proceeding with current version."
)
elif not needs_tf_v5:
logger.warning(
@@ -17,6 +17,7 @@ class KVArgs:
kv_data_ptrs: List[int]
kv_data_lens: List[int]
kv_item_lens: List[int]
+ aux_buffer_names: List[str]
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
@@ -24,6 +24,7 @@ from sglang.srt.disaggregation.base.conn import (
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import (
+ get_attention_cp_size,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
@@ -116,10 +117,21 @@ class CommonKVManager(BaseKVManager):
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
+ route_attn_tp_rank = self.attn_tp_rank
+ # In prefill CP mode, attention TP rank is flattened to 0, but requests are
+ # still routed by engine rank; register by engine rank to preserve all routes.
+ # Only apply this when actual CP is in use (cp_size > 1), not in pure DP
+ # attention mode (e.g. EP64) where each rank has its own dp_group already.
+ if (
+ self.disaggregation_mode == DisaggregationMode.PREFILL
+ and self.attn_tp_size == 1
+ and get_attention_cp_size() > 1
+ ):
+ route_attn_tp_rank = self.kv_args.engine_rank
payload = {
"role": "Prefill",
"attn_tp_size": self.attn_tp_size,
- "attn_tp_rank": self.attn_tp_rank,
+ "attn_tp_rank": route_attn_tp_rank,
"attn_dp_size": self.attn_dp_size,
"attn_dp_rank": self.attn_dp_rank,
"pp_size": self.pp_size,
@@ -333,6 +345,10 @@ class CommonKVReceiver(BaseKVReceiver):
self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
)
+ # With attention DP, one request is routed to one decode rank.
+ # Waiting for all TP shards to pre-allocate the same bootstrap room would stall forever.
+ if self.kv_mgr.attn_dp_size > 1:
+ self.required_dst_info_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
@@ -357,6 +373,11 @@ class CommonKVReceiver(BaseKVReceiver):
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
+ # For prefill CP mode (decode attention TP=1, prefill attention TP>1),
+ # route bootstrap to all prefill ranks as non-dummy so the serving rank
+ # always receives decode-side metadata.
+ if self.kv_mgr.attn_tp_size == 1 and self.prefill_attn_tp_size > 1:
+ self.target_tp_rank = None
self.required_dst_info_num = 1
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
@@ -422,6 +443,7 @@ class CommonKVReceiver(BaseKVReceiver):
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+ self.bootstrap_infos = None
return
self.bootstrap_infos = bootstrap_infos
@@ -610,8 +632,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
and int(target_dp_group) == -1
and int(target_pp_rank) == -1
):
+ inferred_attn_tp_size = max(
+ (len(v) for v in self.prefill_port_table.values()),
+ default=self.attn_tp_size,
+ )
prefill_parallel_info = {
- "prefill_attn_tp_size": self.attn_tp_size,
+ "prefill_attn_tp_size": inferred_attn_tp_size,
"prefill_dp_size": self.dp_size,
"prefill_pp_size": self.pp_size,
"prefill_page_size": self.page_size,
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
+import os
import time
from collections import deque
from dataclasses import dataclass
@@ -40,8 +41,10 @@ from sglang.srt.disaggregation.utils import (
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
+ apply_prefill_timing_payload,
get_kv_class,
is_mla_backend,
+ is_slime_profiling_enabled,
kv_to_page_indices,
poll_and_all_reduce,
prepare_abort,
@@ -295,6 +298,7 @@ class DecodePreallocQueue:
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
+ kv_args.aux_buffer_names = self.metadata_buffers.get_aux_buffer_names()
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
@@ -336,6 +340,16 @@ class DecodePreallocQueue:
)
return kv_manager
+ def release_memory_occupation(self):
+ """Drop queued requests and deregister transfer buffers, if supported.
+
+ Clears both the pending queue and the retracted queue, then calls the
+ KV manager's ``deregister_buffer_to_engine`` when the manager has one.
+ """
+ self.queue.clear()
+ self.retracted_queue.clear()
+ # hasattr guard: managers without this method are left untouched.
+ if hasattr(self.kv_manager, "deregister_buffer_to_engine"):
+ self.kv_manager.deregister_buffer_to_engine()
+
+ def resume_memory_occupation(self):
+ """Re-register transfer buffers through the KV manager, if supported.
+
+ This method does not touch the queues; it only calls the manager's
+ ``register_buffer_to_engine`` when the manager has one.
+ """
+ if hasattr(self.kv_manager, "register_buffer_to_engine"):
+ self.kv_manager.register_buffer_to_engine()
+
def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
if self._check_if_req_exceed_kv_capacity(req):
@@ -440,12 +454,37 @@ class DecodePreallocQueue:
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
+ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed.
+ bootstrap_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue
if poll == KVPoll.Bootstrapping:
- pass
+ # Check for bootstrap timeout
+ entry_time = getattr(
+ decode_req.req.time_stats,
+ "decode_prealloc_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > bootstrap_timeout:
+ error_message = (
+ f"Decode bootstrap timed out after {now - entry_time:.1f}s "
+ f"for request rank={self.tp_rank} "
+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ prepare_abort(
+ decode_req.req,
+ error_message,
+ status_code=HTTPStatus.GATEWAY_TIMEOUT,
+ )
+ if self.scheduler.enable_metrics:
+ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
elif poll == KVPoll.WaitingForInput:
decode_req.waiting_for_input = True
elif poll == KVPoll.Failed:
@@ -590,6 +629,7 @@ class DecodePreallocQueue:
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert decode_req.metadata_buffer_index is not None
+ self.metadata_buffers.clear_profiling_buf(decode_req.metadata_buffer_index)
page_indices = kv_to_page_indices(kv_indices, page_size)
decode_req.kv_receiver.init(
page_indices, decode_req.metadata_buffer_index, state_indices
@@ -751,6 +791,7 @@ class DecodeTransferQueue:
output_topk_index,
output_hidden_states,
output_bootstrap_room,
+ output_prefill_timing,
) = self.metadata_buffers.get_buf(idx)
# Validate bootstrap_room to detect context corruption
@@ -813,6 +854,14 @@ class DecodeTransferQueue:
output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist()
)
+ # Inject prefill-side PD timing forwarded from the P instance.
+ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap,
+ # alloc_waiting, transfer_speed, transfer_mb, retry_count]
+ if is_slime_profiling_enabled():
+ apply_prefill_timing_payload(
+ decode_req.req.time_stats, output_prefill_timing
+ )
+
decode_req.kv_receiver.clear()
decode_req.kv_receiver = None
trace_slice_end(
@@ -830,6 +879,13 @@ class DecodeTransferQueue:
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
+ # Transfer timeout: if a request has been in the transfer queue for too long
+ # (e.g., stuck in Bootstrapping/WaitingForInput/Transferring), treat it as failed.
+ transfer_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
transferred_reqs = []
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -877,7 +933,20 @@ class DecodeTransferQueue:
KVPoll.WaitingForInput,
KVPoll.Transferring,
]:
- pass
+ # Check for transfer timeout
+ entry_time = getattr(
+ decode_req.req.time_stats,
+ "decode_transfer_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > transfer_timeout:
+ error_message = (
+ f"Decode transfer timed out after {now - entry_time:.1f}s "
+ f"(state={poll}) for request rank={self.tp_rank} "
+ f"{decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ decode_req.kv_receiver.abort()
else:
raise ValueError(f"Unexpected poll case: {poll}")
@@ -893,6 +962,14 @@ class DecodeTransferQueue:
return transferred_reqs
+ def release_memory_occupation(self):
+ """Clean up all in-flight transfers before releasing GPU memory."""
+ # Only the queue is emptied here; the kv_receiver of each dropped
+ # request is not aborted or cleared by this method.
+ self.queue.clear()
+
+ def resume_memory_occupation(self):
+ """Resume after GPU memory re-allocation. Queue was already cleared on release."""
+ # Intentional no-op: this queue has nothing to restore or re-register.
+ pass
+
class SchedulerDisaggregationDecodeMixin:
@@ -1072,7 +1149,15 @@ class SchedulerDisaggregationDecodeMixin:
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
self.waiting_queue.extend(resumed_reqs)
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
- # if there are still retracted requests, we do not allocate new requests
+ # Still have retracted requests that couldn't resume (not enough memory).
+ # Don't accept new requests (pop_preallocated) — they would consume memory
+ # that retracted requests need.
+ # But DO drain completed transfers: their KV is already committed, and
+ # moving them to waiting_queue frees the reserved-decode-token budget
+ # in _allocatable_tokens(), which may unblock resume on the next iteration.
+ # Without this, completed transfers hold memory indefinitely → deadlock.
+ alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred()
+ self.waiting_queue.extend(alloc_reqs)
return
if not hasattr(self, "polling_count"):
@@ -117,7 +117,7 @@ def _convert(data):
return data
-_image_grid_attrs = ["image_grid_thw", "image_grid_hws"]
+_image_grid_attrs = ["image_grid_thw", "image_grid_hws", "grid_thws"]
def _get_image_grid_dim(images_input):
@@ -320,7 +320,26 @@ class MMEncoder:
try:
kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
- images_input = self.image_processor(images=images, **kwargs)
+ # Some processors (e.g., KimiK25VisionProcessor) expect MediaInput
+ # dicts rather than raw PIL Images. Wrap PIL images as needed.
+ from PIL import Image as PILImage
+
+ if (
+ isinstance(images, (list, tuple))
+ and images
+ and isinstance(images[0], PILImage.Image)
+ ):
+ import inspect
+
+ sig = inspect.signature(self.image_processor.preprocess)
+ first_param = list(sig.parameters.keys())[0]
+ if first_param == "medias":
+ medias = [{"type": "image", "image": img} for img in images]
+ images_input = self.image_processor.preprocess(medias, **kwargs)
+ else:
+ images_input = self.image_processor(images=images, **kwargs)
+ else:
+ images_input = self.image_processor(images=images, **kwargs)
feature = images_input["pixel_values"]
mm_item = MultimodalDataItem.from_dict(
{
@@ -30,7 +30,7 @@ from sglang.srt.disaggregation.common.utils import (
from sglang.srt.disaggregation.mooncake.utils import (
check_mooncake_custom_mem_pool_enabled,
)
-from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.disaggregation.utils import DisaggregationMode, iter_aux_transfer_specs
from sglang.srt.distributed.parallel_state import get_mooncake_transfer_engine
from sglang.srt.environ import envs
from sglang.srt.server_args import ServerArgs
@@ -260,6 +260,19 @@ class MooncakeKVManager(CommonKVManager):
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
+ def deregister_buffer_to_engine(self):
+ """Deregister the KV, auxiliary and state buffers from the engine.
+
+ Each pointer list on ``self.kv_args`` is handed to
+ ``self.engine.batch_deregister`` in one call, and only when the list
+ is non-empty.
+ """
+ # Batch deregister KV data buffers
+ if self.kv_args.kv_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.kv_data_ptrs)
+
+ # Batch deregister auxiliary data buffers
+ if self.kv_args.aux_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.aux_data_ptrs)
+
+ # Batch deregister state/extra pool data buffers
+ if self.kv_args.state_data_ptrs:
+ self.engine.batch_deregister(self.kv_args.state_data_ptrs)
+
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
@@ -524,10 +537,14 @@ class MooncakeKVManager(CommonKVManager):
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens
- for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
- length = prefill_aux_item_lens[i]
- src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
- dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
+ for _, src_addr, dst_addr, length in iter_aux_transfer_specs(
+ self.kv_args.aux_buffer_names,
+ prefill_aux_ptrs,
+ prefill_aux_item_lens,
+ dst_aux_ptrs,
+ prefill_aux_index,
+ req.dst_aux_index,
+ ):
transfer_blocks.append((src_addr, dst_addr, length))
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
@@ -541,9 +558,14 @@ class MooncakeKVManager(CommonKVManager):
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens
- for i in range(len(prefill_aux_ptrs)):
- length = prefill_aux_item_lens[i]
- src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+ for i, src_addr, _, length in iter_aux_transfer_specs(
+ self.kv_args.aux_buffer_names,
+ prefill_aux_ptrs,
+ prefill_aux_item_lens,
+ dst_aux_ptrs,
+ prefill_aux_index,
+ req.dst_aux_index,
+ ):
data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
self.send_aux_data_to_endpoint(
@@ -643,13 +665,13 @@ class MooncakeKVManager(CommonKVManager):
raise RuntimeError(
f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet."
)
- if len(prefill_state_indices) < len(req.dst_state_indices):
- logger.warning(
- f"len(prefill_state_indices) = {len(prefill_state_indices)}, len(dst_state_indices) = {len(req.dst_state_indices)}"
+ if len(prefill_state_indices) != len(req.dst_state_indices):
+ logger.error(
+ "PD extra-state index mismatch, reject transfer to avoid corrupted outputs: "
+ f"len(prefill_state_indices)={len(prefill_state_indices)}, "
+ f"len(dst_state_indices)={len(req.dst_state_indices)}"
)
- prefill_state_indices = prefill_state_indices[
- : len(req.dst_state_indices)
- ]
+ return -1
# Reuse _send_kvcache_generic interface to send extra pool data
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
@@ -858,12 +880,6 @@ class MooncakeKVManager(CommonKVManager):
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
- # Failures should never happen if the session is not dead, if the session fails once, mark it as failed
- if self.session_failures[req.mooncake_session_id] >= 1:
- self.failed_sessions.add(req.mooncake_session_id)
- logger.error(
- f"Session {req.mooncake_session_id} failed."
- )
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
@@ -880,13 +896,31 @@ class MooncakeKVManager(CommonKVManager):
if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
- self.maybe_send_extra(
+ ret = self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
executor,
target_rank_registration_info,
)
+ if ret != 0:
+ with self.session_lock:
+ self.session_failures[
+ req.mooncake_session_id
+ ] += 1
+ self.record_failure(
+ kv_chunk.room,
+ f"Failed to send extra state chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
+ )
+ self.update_status(kv_chunk.room, KVPoll.Failed)
+ self.sync_status_to_decode_endpoint(
+ req.endpoint,
+ req.dst_port,
+ req.room,
+ KVPoll.Failed,
+ local_rank,
+ )
+ break
# Only the last chunk we need to send the aux data
ret = self.send_aux(
@@ -895,6 +929,11 @@ class MooncakeKVManager(CommonKVManager):
target_rank_registration_info.dst_aux_ptrs,
)
polls.append(True if ret == 0 else False)
+ if ret != 0:
+ # Count the aux-send failure against this session. Only the
+ # counter is incremented here; failed_sessions is not updated.
+ with self.session_lock:
+ self.session_failures[req.mooncake_session_id] += 1
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
)
@@ -977,15 +1016,20 @@ class MooncakeKVManager(CommonKVManager):
if status == KVPoll.Success:
if bootstrap_room in self.request_status:
- self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
+ # Guard against TOCTOU race: clear() may remove the entry
+ # between the request_status check and dict access here.
expected_response_num = (
- self.required_prefill_response_num_table[bootstrap_room]
+ self.required_prefill_response_num_table.get(bootstrap_room)
)
- arrived_response_num = len(
- self.prefill_response_tracker[bootstrap_room]
- )
- if arrived_response_num == expected_response_num:
- self.update_status(bootstrap_room, KVPoll.Success)
+ if expected_response_num is not None:
+ self.prefill_response_tracker[bootstrap_room].add(
+ prefill_rank
+ )
+ arrived_response_num = len(
+ self.prefill_response_tracker[bootstrap_room]
+ )
+ if arrived_response_num == expected_response_num:
+ self.update_status(bootstrap_room, KVPoll.Success)
elif status == KVPoll.Failed:
self.record_failure(
bootstrap_room,
@@ -1266,7 +1310,10 @@ class MooncakeKVReceiver(CommonKVReceiver):
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
+ # Only transition to WaitingForInput if bootstrap succeeded;
+ # if super().__init__() set status to Failed, do not override it.
+ if self.bootstrap_infos is not None:
+ self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from __future__ import annotations
import logging
+import os
import time
from collections import deque
from http import HTTPStatus
@@ -167,6 +168,7 @@ class PrefillBootstrapQueue:
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
+ kv_args.aux_buffer_names = self.metadata_buffers.get_aux_buffer_names()
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
@@ -276,6 +278,12 @@ class PrefillBootstrapQueue:
[req.disagg_kv_sender for req in self.queue], self.gloo_group
)
+ # Bootstrap timeout: if a request has been stuck in Bootstrapping for too long, treat it as failed.
+ bootstrap_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if rids_to_check is not None:
# if req not in reqs_info_to_check, skip
@@ -283,6 +291,27 @@ class PrefillBootstrapQueue:
continue
if poll == KVPoll.Bootstrapping:
+ # Check for bootstrap timeout
+ entry_time = getattr(
+ req.time_stats,
+ "prefill_bootstrap_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > bootstrap_timeout:
+ error_message = (
+ f"Prefill bootstrap timed out after {now - entry_time:.1f}s "
+ f"for request rank={self.tp_rank} "
+ f"{req.rid=} {req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ prepare_abort(
+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT
+ )
+ self.scheduler.stream_output([req], req.return_logprob)
+ indices_to_remove.add(i)
+ failed_reqs.append(req)
+ if self.scheduler.enable_metrics:
+ self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
continue
elif poll == KVPoll.Failed:
error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
@@ -335,6 +364,15 @@ class PrefillBootstrapQueue:
else:
return bootstrapped_reqs, failed_reqs
+ def release_memory_occupation(self):
+ """Drop queued requests and deregister transfer buffers, if supported.
+
+ Clears the bootstrap queue, then calls the KV manager's
+ ``deregister_buffer_to_engine`` when the manager has one.
+ """
+ self.queue.clear()
+ # hasattr guard: managers without this method are left untouched.
+ if hasattr(self.kv_manager, "deregister_buffer_to_engine"):
+ self.kv_manager.deregister_buffer_to_engine()
+
+ def resume_memory_occupation(self):
+ """Re-register transfer buffers through the KV manager, if supported.
+
+ The queue itself is not modified by this method.
+ """
+ if hasattr(self.kv_manager, "register_buffer_to_engine"):
+ self.kv_manager.register_buffer_to_engine()
+
class SchedulerDisaggregationPrefillMixin:
"""
@@ -547,6 +585,18 @@ class SchedulerDisaggregationPrefillMixin:
self.maybe_send_health_check_signal()
+ if (
+ self.current_scheduler_metrics_enabled
+ and hasattr(batch, "prefill_stats")
+ and batch.prefill_stats is not None
+ ):
+ can_run_cuda_graph = getattr(result, "can_run_cuda_graph", False)
+ self.log_prefill_stats(
+ prefill_stats=batch.prefill_stats,
+ can_run_cuda_graph=can_run_cuda_graph,
+ dp_cooperation_info=getattr(batch, "dp_cooperation_info", None),
+ )
+
def process_disagg_prefill_inflight_queue(
self: Scheduler, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
@@ -559,11 +609,24 @@ class SchedulerDisaggregationPrefillMixin:
done_reqs = []
+ # When CP > 1, use the full TP gloo group so all CP ranks reach
+ # consensus; otherwise a subset may enter run_batch while others wait
+ # in recv_requests, causing a deadlock.
+ disagg_gloo_group = (
+ self.tp_cpu_group if self.attn_cp_size > 1 else self.attn_tp_cpu_group
+ )
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
- self.attn_tp_cpu_group,
+ disagg_gloo_group,
)
+ # Transfer timeout: if a request has been in the inflight queue for too long
+ # (e.g., stuck in WaitingForInput/Transferring), treat it as failed.
+ transfer_timeout = float(
+ os.environ.get("SGLANG_DISAGGREGATION_TRANSFER_TIMEOUT", "600")
+ )
+ now = time.perf_counter()
+
undone_reqs: List[Req] = []
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
@@ -573,10 +636,35 @@ class SchedulerDisaggregationPrefillMixin:
undone_reqs.append(req)
continue
- assert poll == KVPoll.Success or poll == KVPoll.Failed
+ if poll not in (KVPoll.Success, KVPoll.Failed):
+ undone_reqs.append(req)
+ continue
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
- undone_reqs.append(req)
+ # Check for transfer timeout
+ entry_time = getattr(
+ req.time_stats,
+ "prefill_transfer_queue_entry_time",
+ None,
+ )
+ if entry_time is not None and (now - entry_time) > transfer_timeout:
+ error_message = (
+ f"Prefill transfer timed out after {now - entry_time:.1f}s "
+ f"(state={poll}) for request rank={self.tp_rank} "
+ f"{req.rid=} {req.bootstrap_room=}"
+ )
+ logger.error(error_message)
+ release_kv_cache(req, self.tree_cache) # unlock the tree
+ prepare_abort(
+ req, error_message, status_code=HTTPStatus.GATEWAY_TIMEOUT
+ )
+ if hasattr(req.disagg_kv_sender, "clear"):
+ req.disagg_kv_sender.clear()
+ done_reqs.append(req)
+ if self.enable_metrics:
+ self.metrics_collector.increment_transfer_failed_reqs()
+ else:
+ undone_reqs.append(req)
elif poll == KVPoll.Success: # transfer done
release_kv_cache(req, self.tree_cache) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
@@ -628,9 +716,12 @@ class SchedulerDisaggregationPrefillMixin:
"""
Used by PP, get the transferred rids but **do not pop**
"""
+ disagg_gloo_group = (
+ self.tp_cpu_group if self.attn_cp_size > 1 else self.attn_tp_cpu_group
+ )
polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
- self.attn_tp_cpu_group,
+ disagg_gloo_group,
)
transferred_rids: List[str] = []
@@ -21,6 +21,17 @@ if TYPE_CHECKING:
# Constants & Enums
#########################
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
+# Name under which the prefill timing tensor is listed among the aux buffers.
+PREFILL_TIMING_AUX_BUFFER_NAME = "prefill_timing"
+# (attribute name, caster) pairs, one per timing value and in the same order
+# as the values in the timing payload; each positive value is cast and stored
+# on the receiving time_stats object under the given attribute name.
+PREFILL_TIMING_DEST_ATTRS = (
+ ("fwd_prefill_bootstrap_queue_duration", float),
+ ("fwd_prefill_forward_duration", float),
+ ("fwd_prefill_transfer_queue_duration", float),
+ ("fwd_bootstrap_duration", float),
+ ("fwd_alloc_waiting_duration", float),
+ ("fwd_transfer_speed_gb_s", float),
+ ("fwd_transfer_total_mb", float),
+ ("fwd_prefill_retry_count", int),
+)
class DisaggregationMode(Enum):
@@ -139,46 +150,35 @@ class MetadataBuffers:
self.bootstrap_room = torch.zeros(
(size, 8), dtype=torch.uint64, device=device
)
+ # Prefill-side PD timing (8 floats, padded to 16 for RDMA alignment).
+ # Layout: [bootstrap_queue, forward, transfer_queue, bootstrap,
+ # alloc_waiting, transfer_speed, transfer_mb, retry_count]
+ self.prefill_timing = torch.zeros(
+ (size, 16), dtype=torch.float32, device=device
+ )
+ self.aux_buffers = [
+ ("output_ids", self.output_ids),
+ ("cached_tokens", self.cached_tokens),
+ ("output_token_logprobs_val", self.output_token_logprobs_val),
+ ("output_token_logprobs_idx", self.output_token_logprobs_idx),
+ ("output_top_logprobs_val", self.output_top_logprobs_val),
+ ("output_top_logprobs_idx", self.output_top_logprobs_idx),
+ ("output_topk_p", self.output_topk_p),
+ ("output_topk_index", self.output_topk_index),
+ ("output_hidden_states", self.output_hidden_states),
+ ("bootstrap_room", self.bootstrap_room),
+ (PREFILL_TIMING_AUX_BUFFER_NAME, self.prefill_timing),
+ ]
def get_buf_infos(self):
+ """Return ``(ptrs, data_lens, item_lens)`` for all aux buffers.
+
+ The three lists are parallel and follow the order of
+ ``self.aux_buffers``: base data pointer, total byte size, and byte
+ size of row 0 of each buffer.
+ """
- ptrs = [
- self.output_ids.data_ptr(),
- self.cached_tokens.data_ptr(),
- self.output_token_logprobs_val.data_ptr(),
- self.output_token_logprobs_idx.data_ptr(),
- self.output_top_logprobs_val.data_ptr(),
- self.output_top_logprobs_idx.data_ptr(),
- self.output_topk_p.data_ptr(),
- self.output_topk_index.data_ptr(),
- self.output_hidden_states.data_ptr(),
- self.bootstrap_room.data_ptr(),
- ]
- data_lens = [
- self.output_ids.nbytes,
- self.cached_tokens.nbytes,
- self.output_token_logprobs_val.nbytes,
- self.output_token_logprobs_idx.nbytes,
- self.output_top_logprobs_val.nbytes,
- self.output_top_logprobs_idx.nbytes,
- self.output_topk_p.nbytes,
- self.output_topk_index.nbytes,
- self.output_hidden_states.nbytes,
- self.bootstrap_room.nbytes,
- ]
- item_lens = [
- self.output_ids[0].nbytes,
- self.cached_tokens[0].nbytes,
- self.output_token_logprobs_val[0].nbytes,
- self.output_token_logprobs_idx[0].nbytes,
- self.output_top_logprobs_val[0].nbytes,
- self.output_top_logprobs_idx[0].nbytes,
- self.output_topk_p[0].nbytes,
- self.output_topk_index[0].nbytes,
- self.output_hidden_states[0].nbytes,
- self.bootstrap_room[0].nbytes,
- ]
+ # All three lists are derived from the single aux_buffers registry so
+ # their positions always line up.
+ ptrs = [buffer.data_ptr() for _, buffer in self.aux_buffers]
+ data_lens = [buffer.nbytes for _, buffer in self.aux_buffers]
+ item_lens = [buffer[0].nbytes for _, buffer in self.aux_buffers]
return ptrs, data_lens, item_lens
+ def get_aux_buffer_names(self):
+ """Return the aux buffer names, in ``self.aux_buffers`` order."""
+ return [name for name, _ in self.aux_buffers]
+
def get_buf(self, idx: int):
return (
self.output_ids[idx],
@@ -191,8 +191,12 @@ class MetadataBuffers:
self.output_topk_index[idx],
self.output_hidden_states[idx],
self.bootstrap_room[idx],
+ self.prefill_timing[idx],
)
+ def clear_profiling_buf(self, idx: int):
+ """Zero the prefill timing row at index ``idx`` in place."""
+ self.prefill_timing[idx].zero_()
+
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
@@ -237,6 +241,84 @@ class MetadataBuffers:
self.bootstrap_room[req.metadata_buffer_index, 0] = (
req.bootstrap_room if req.bootstrap_room is not None else 0
)
+ # Pack prefill-side PD timing durations for transfer to decode instance.
+ # Note: set_buf is called at the START of the last KV chunk send, so
+ # completion_time and prefill_transfer_queue_entry_time are not yet set.
+ # We use time.perf_counter() as the "forward just completed" timestamp.
+ import time
+
+ ts = req.time_stats
+ timing = self.prefill_timing[req.metadata_buffer_index]
+ self.clear_profiling_buf(req.metadata_buffer_index)
+ if not is_slime_profiling_enabled():
+ return
+ for idx, value in enumerate(
+ build_prefill_timing_payload(ts, now=time.perf_counter())
+ ):
+ if value > 0:
+ timing[idx] = value
+
+
+def is_slime_profiling_enabled() -> bool:
+ """Return the current value of ``envs.SLIME_ENABLE_PROFILING``."""
+ return envs.SLIME_ENABLE_PROFILING.get()
+
+
+def build_prefill_timing_payload(time_stats, now: float) -> tuple[float, ...]:
+ """Build the 8-value prefill timing tuple from ``time_stats``.
+
+ Value order: bootstrap-queue duration, forward duration, transfer-queue
+ duration, bootstrap_duration, alloc_waiting_duration,
+ transfer_speed_gb_s, transfer_total_mb, prefill_retry_count.
+
+ ``now`` is used as the end timestamp of the forward pass. Timestamps
+ that are not greater than 0 are treated as unset and produce 0.0.
+ """
+ bootstrap_queue_duration = 0.0
+ # Both endpoints must be set before the difference is meaningful.
+ if (
+ time_stats.prefill_bootstrap_queue_entry_time > 0
+ and time_stats.wait_queue_entry_time > 0
+ ):
+ bootstrap_queue_duration = (
+ time_stats.wait_queue_entry_time
+ - time_stats.prefill_bootstrap_queue_entry_time
+ )
+
+ prefill_forward_duration = (
+ now - time_stats.forward_entry_time
+ if time_stats.forward_entry_time > 0
+ else 0.0
+ )
+
+ # The third slot (transfer-queue duration) is always 0.0 in this builder;
+ # the remaining stats are clamped so negative values become 0.
+ return (
+ bootstrap_queue_duration,
+ prefill_forward_duration,
+ 0.0,
+ max(0.0, time_stats.bootstrap_duration),
+ max(0.0, time_stats.alloc_waiting_duration),
+ max(0.0, time_stats.transfer_speed_gb_s),
+ max(0.0, time_stats.transfer_total_mb),
+ float(max(0, time_stats.prefill_retry_count)),
+ )
+
+
+def apply_prefill_timing_payload(time_stats, timing) -> None:
+ """Copy positive timing values onto ``time_stats``.
+
+ Reads the first ``len(PREFILL_TIMING_DEST_ATTRS)`` entries of ``timing``
+ and pairs them positionally with the (attribute name, caster) table.
+ Entries that are not greater than 0 leave the attribute unchanged.
+ """
+ for value, (attr_name, caster) in zip(
+ timing[: len(PREFILL_TIMING_DEST_ATTRS)].tolist(),
+ PREFILL_TIMING_DEST_ATTRS,
+ ):
+ if value > 0:
+ setattr(time_stats, attr_name, caster(value))
+
+
+def iter_aux_transfer_specs(
+ aux_buffer_names: list[str],
+ prefill_aux_ptrs: list[int],
+ prefill_aux_item_lens: list[int],
+ dst_aux_ptrs: list[int],
+ prefill_aux_index: int,
+ dst_aux_index: int,
+):
+ """Yield ``(i, src_addr, dst_addr, length)`` for each aux buffer to send.
+
+ Buffers are paired by position: ``aux_buffer_names`` is zipped with
+ ``dst_aux_ptrs``, so iteration stops at the shorter of the two. The
+ prefill timing buffer is skipped when profiling is disabled, and any
+ buffer whose item length is not positive is skipped. ``i`` is the
+ position in the original lists, not a count of yielded items.
+ """
+ # Read the flag once so every buffer in this call sees the same setting.
+ profiling_enabled = is_slime_profiling_enabled()
+ for i, (buffer_name, dst_aux_ptr) in enumerate(zip(aux_buffer_names, dst_aux_ptrs)):
+ if not profiling_enabled and buffer_name == PREFILL_TIMING_AUX_BUFFER_NAME:
+ continue
+ length = prefill_aux_item_lens[i]
+ if length <= 0:
+ continue
+ # The prefill-side item length is used as the stride on both sides.
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+ dst_addr = dst_aux_ptr + length * dst_aux_index
+ yield i, src_addr, dst_addr, length
#########################
@@ -1999,7 +1999,10 @@ def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
- return get_tp_group().rank_in_group
+ # Falls back to rank 0 when the TP group lookup raises.
+ # NOTE(review): every Exception is swallowed here, not only the
+ # "group not initialized" case - confirm callers expect rank 0 then.
+ try:
+ return get_tp_group().rank_in_group
+ except Exception:
+ return 0
# ATTN_TP
@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
MultimodalDataInputFormat,
OpenSessionReqInput,
+ PostProcessWeightsReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
@@ -641,6 +642,24 @@ class Engine(EngineBase):
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
+ def post_process_weights(
+ self,
+ restore_weights_before_load: bool = False,
+ post_process_quantization: bool = False,
+ ):
+ """
+ Optional post-processing for updated weights (e.g., Marlin conversion).
+ Should be called after weight update is finished.
+ """
+ # Both flags are forwarded unchanged inside the request object.
+ obj = PostProcessWeightsReqInput(
+ restore_weights_before_load=restore_weights_before_load,
+ post_process_quantization=post_process_quantization,
+ )
+
+ # Runs the tokenizer manager coroutine to completion on self.loop and
+ # returns its result; the second argument (request) is None here.
+ return self.loop.run_until_complete(
+ self.tokenizer_manager.post_process_weights(obj, None)
+ )
+
def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
@@ -115,6 +115,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -574,10 +575,8 @@ async def model_info():
@app.get("/weight_version")
async def weight_version():
"""Get the current weight version."""
- raise HTTPException(
- status_code=404,
- detail="Endpoint '/get_weight_version' or '/weight_version' is deprecated. Please use '/model_info' instead.",
- )
+ result = await model_info()
+ return {"weight_version": result.get("weight_version", None)}
@app.get("/get_server_info")
@@ -594,9 +593,19 @@ async def get_server_info():
async def server_info():
"""Get the server information."""
# Returns internal states per DP.
- internal_states: List[Dict[Any, Any]] = (
- await _global_state.tokenizer_manager.get_internal_state()
- )
+ # In large/disaggregated deployments this can occasionally block; keep endpoint responsive.
+ server_info_timeout = float(os.environ.get("SGLANG_SERVER_INFO_TIMEOUT", "2"))
+ try:
+ internal_states: List[Dict[Any, Any]] = await asyncio.wait_for(
+ _global_state.tokenizer_manager.get_internal_state(),
+ timeout=server_info_timeout,
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ "Timed out getting internal state for /server_info after %.1fs; returning empty internal_states",
+ server_info_timeout,
+ )
+ internal_states = []
# This field is not serializable.
if hasattr(_global_state.tokenizer_manager.server_args, "model_config"):
@@ -1084,6 +1093,23 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+@app.post("/post_process_weights")
+@auth_level(AuthLevel.ADMIN_OPTIONAL)
+async def post_process_weights(req: PostProcessWeightsReqInput, request: Request):
+ """
+ Optional post-processing for updated weights (e.g., Marlin conversion).
+ This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`.
+ """
+ success, message = await _global_state.tokenizer_manager.post_process_weights(
+ req, request
+ )
+
+ content = {"success": success, "message": message}
+ return ORJSONResponse(
+ content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+ )
+
+
@app.post("/update_weight_version")
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
@@ -244,6 +244,7 @@ class Envs:
SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE = EnvInt(2)
SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300)
SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX")
+ SLIME_ENABLE_PROFILING = EnvBool(False)
# Scheduler: others:
SGLANG_EMPTY_CACHE_INTERVAL = EnvFloat(-1) # in seconds. Set if you observe high memory accumulation over a long serving period.
@@ -630,7 +630,6 @@ def _get_k_and_s_triton(
page_indices,
k_out,
s_out,
- seq_len,
page_size,
buf_numel_per_page,
index_head_dim,
@@ -647,7 +646,6 @@ def _get_k_and_s_triton_kernel(
page_indices_ptr,
k_out_ptr,
s_out_ptr,
- seq_len: tl.constexpr,
page_size: tl.constexpr,
buf_numel_per_page: tl.constexpr,
index_head_dim: tl.constexpr,
@@ -1,6 +1,7 @@
from __future__ import annotations
import contextlib
+import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -201,14 +202,31 @@ class Indexer(MultiPlatformOp):
prefix=add_prefix("weights_proj", prefix),
)
self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
+ server_args = get_global_server_args()
+ disable_flag = server_args.disable_indexer_rope_neox_style
+ env_raw = os.environ.get("INDEXER_ROPE_NEOX_STYLE", None)
+ if env_raw is not None:
+ env_value = env_raw == "1"
+ if disable_flag and env_value:
+ raise ValueError(
+ "Conflict: --disable-indexer-rope-neox-style is set but "
+ "INDEXER_ROPE_NEOX_STYLE='1'. "
+ "Please remove one or make them consistent."
+ )
+ resolved_neox_style = env_value
+ elif disable_flag:
+ resolved_neox_style = False
+ else:
+ resolved_neox_style = is_neox_style
+
self.rotary_emb = get_rope_wrapper(
rope_head_dim,
rotary_dim=rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
- is_neox_style=is_neox_style,
- device=get_global_server_args().device,
+ is_neox_style=resolved_neox_style,
+ device=server_args.device,
)
self.block_size = block_size
self.scale_fmt = scale_fmt
@@ -244,6 +262,11 @@ class Indexer(MultiPlatformOp):
x = x.to(self.weights_proj.weight.dtype)
weights, _ = self.weights_proj(x)
weights = weights.float()
+ if weights.shape[1] < q_scale.shape[1]:
+ assert q_scale.shape[1] % weights.shape[1] == 0
+ weights = weights.repeat_interleave(
+ q_scale.shape[1] // weights.shape[1], dim=1
+ )
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
@@ -982,15 +1005,26 @@ class Indexer(MultiPlatformOp):
query, key = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
)
+ if query.shape[1] < 32:
+ assert 32 % query.shape[1] == 0
+ query = query.repeat_interleave(32 // query.shape[1], dim=1)
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
with torch.cuda.stream(self.alt_stream):
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
current_stream.wait_stream(self.alt_stream)
+ if weights.shape[1] < q_scale.shape[1]:
+ assert q_scale.shape[1] % weights.shape[1] == 0
+ weights = weights.repeat_interleave(
+ q_scale.shape[1] // weights.shape[1], dim=1
+ )
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
else:
query, key = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
)
+ if query.shape[1] < 32:
+ assert 32 % query.shape[1] == 0
+ query = query.repeat_interleave(32 // query.shape[1], dim=1)
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
@@ -91,20 +91,29 @@ def nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]):
def cal_padded_tokens(forward_batch: "ForwardBatch"):
# Consistent with the padding calculation logic in ForwardBatch.prepare_mlp_sync_batch,
# calculate the actual token length after padding when attn_tp_size > 1 or in the MAX_LEN padding mode.
- global_num_tokens = forward_batch.global_num_tokens_cpu.copy()
+ if forward_batch.global_num_tokens_cpu is None:
+ # PD prefill CP+PP path can bypass MLP-sync metadata. Reconstruct a single-rank
+ # global token view from the local token count for NSA padding logic.
+ local_tokens = forward_batch.num_token_non_padded_cpu
+ if local_tokens is None:
+ local_tokens = len(forward_batch.input_ids)
+ global_num_tokens = [local_tokens * get_attention_cp_size()]
+ else:
+ global_num_tokens = forward_batch.global_num_tokens_cpu.copy()
sync_group_size = len(global_num_tokens)
attn_cp_size = get_attention_cp_size()
for i in range(sync_group_size):
global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_cp_size)
- dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
- forward_batch.is_extend_in_batch, global_num_tokens
- )
- if dp_padding_mode.is_max_len():
- tokens = max(global_num_tokens)
- elif len(global_num_tokens) > 1:
- tokens = global_num_tokens[get_attention_dp_rank()]
- else:
+ if len(global_num_tokens) == 1:
tokens = global_num_tokens[0]
+ else:
+ dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+ forward_batch.is_extend_in_batch, global_num_tokens
+ )
+ if dp_padding_mode.is_max_len():
+ tokens = max(global_num_tokens)
+ else:
+ tokens = global_num_tokens[get_attention_dp_rank()]
if can_nsa_prefill_cp_round_robin_split(forward_batch):
tokens = ceil_div(tokens, attn_cp_size)
return tokens
@@ -175,10 +184,6 @@ def can_cp_split(seq_len: int, cp_size: int, use_nsa: bool, forward_batch):
def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor):
if is_nsa_prefill_cp_round_robin_split():
- cp_size = get_attention_cp_size()
- assert (
- input_.shape[0] % cp_size == 0
- ), f"Expect input shape 0 can divided by cp size, but got input shape {input_.shape}, cp size {cp_size}"
return nsa_cp_round_robin_split_data(input_)
input_list = list(
@@ -192,11 +197,6 @@ def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor):
def cp_split_and_rebuild_position(forward_batch, positions: torch.Tensor):
if is_nsa_prefill_cp_round_robin_split():
- cp_size = get_attention_cp_size()
- assert positions.shape[0] % cp_size == 0, (
- f"Expect positions shape 0 can divided by cp size, but got positions shape {positions.shape}, "
- f"cp size {cp_size}"
- )
return nsa_cp_round_robin_split_data(positions)
position_id_list = list(
@@ -34,7 +34,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
attn_cp_all_gather_into_tensor,
attn_cp_reduce_scatter_tensor,
- get_local_dp_buffer,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -153,9 +152,23 @@ class NSACPCommunicateWithAllReduceAndLayerNormFn(
# for decode: attn tp full -> full
if nsa_use_prefill_cp(forward_batch):
assert context.attn_dp_size == 1
- hidden_states, local_hidden_states = (
- get_local_dp_buffer(),
- hidden_states,
+ local_hidden_states = hidden_states
+ total_tokens = (
+ sum(forward_batch.extend_seq_lens_cpu)
+ if forward_batch.extend_seq_lens_cpu is not None
+ else local_hidden_states.shape[0] * context.attn_cp_size
+ )
+ max_len = (total_tokens + context.attn_cp_size - 1) // context.attn_cp_size
+ if local_hidden_states.shape[0] < max_len:
+ pad = local_hidden_states.new_zeros(
+ (
+ max_len - local_hidden_states.shape[0],
+ local_hidden_states.shape[1],
+ )
+ )
+ local_hidden_states = torch.cat([local_hidden_states, pad], dim=0)
+ hidden_states = local_hidden_states.new_empty(
+ (max_len * context.attn_cp_size, local_hidden_states.shape[1])
)
attn_cp_all_gather_into_tensor(
hidden_states,
@@ -90,11 +90,11 @@ class _DpGatheredBufferWrapper:
_hidden_size: int
_dtype: torch.dtype
_device: torch.device
- _global_dp_buffer_len: int
- _local_dp_buffer_len: int
- _dp_max_padding: bool
- _global_num_tokens: Optional[List[int]]
- _is_extend_in_batch: bool
+ _global_dp_buffer_len: int = 0
+ _local_dp_buffer_len: int = 0
+ _dp_max_padding: bool = False
+ _global_num_tokens: Optional[List[int]] = None
+ _is_extend_in_batch: bool = False
@classmethod
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -872,11 +872,6 @@ class LogitsProcessor(nn.Module):
None, # bias
True, # is_vnni
)
- elif get_global_server_args().rl_on_policy_target is not None:
- # Due to tie-weight, we may not be able to change lm_head's weight dtype
- logits = torch.matmul(
- hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
- )
else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
new file mode 100644
@@ -0,0 +1,146 @@
+"""Fused Triton kernels for DeepEP BF16 low-latency MoE decode.
+
+Replaces the naive activation + masking pipeline (5+ CUDA kernels for silu+mul
+and arange+comparison+masked_fill+copy) with a single Triton elementwise kernel,
+while keeping cuBLAS batched GEMM for the matrix multiplies.
+
+Pipeline: bmm → fused_act_mul_masked (in-place) → bmm(out=hidden)
+ (3 ops total: 2 cuBLAS + 1 Triton, vs original 7-8 separate CUDA kernels)
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _silu_mul_masked_kernel(
+ gate_up_ptr,
+ masked_m_ptr,
+ M,
+ N,
+ stride_ge,
+ stride_gm,
+ stride_gn,
+ BLOCK: tl.constexpr,
+):
+ """Fused SiLU(gate) * up with per-expert masking, written in-place.
+
+ gate_up: [E, M, 2*N] — first N cols are gate, last N cols are up.
+ Writes SiLU(gate)*up to gate_up[:,:,:N] in-place.
+ Rows m >= masked_m[e] are zeroed.
+ """
+ expert_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ expert_valid_m = tl.load(masked_m_ptr + expert_id)
+
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
+ total = M * N
+ mask = offs < total
+
+ m = offs // N
+ n = offs % N
+
+ gate_base = gate_up_ptr + expert_id * stride_ge
+
+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0)
+ up_val = tl.load(
+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0
+ )
+
+ gate_f32 = gate_val.to(tl.float32)
+ result = (gate_f32 * tl.sigmoid(gate_f32)) * up_val.to(tl.float32)
+
+ # Zero invalid rows
+ valid = m < expert_valid_m
+ result = tl.where(valid, result, 0.0)
+
+ tl.store(
+ gate_base + m * stride_gm + n * stride_gn,
+ result.to(gate_up_ptr.dtype.element_ty),
+ mask=mask,
+ )
+
+
+@triton.jit
+def _gelu_mul_masked_kernel(
+ gate_up_ptr,
+ masked_m_ptr,
+ M,
+ N,
+ stride_ge,
+ stride_gm,
+ stride_gn,
+ BLOCK: tl.constexpr,
+):
+ """Fused GELU(gate) * up with per-expert masking, written in-place."""
+ expert_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ expert_valid_m = tl.load(masked_m_ptr + expert_id)
+
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
+ total = M * N
+ mask = offs < total
+
+ m = offs // N
+ n = offs % N
+
+ gate_base = gate_up_ptr + expert_id * stride_ge
+
+ gate_val = tl.load(gate_base + m * stride_gm + n * stride_gn, mask=mask, other=0.0)
+ up_val = tl.load(
+ gate_base + m * stride_gm + (n + N) * stride_gn, mask=mask, other=0.0
+ )
+
+ g = gate_val.to(tl.float32)
+ kAlpha = 0.7978845608028654
+ gate_act = 0.5 * g * (1.0 + tl.math.tanh(kAlpha * (g + 0.044715 * g * g * g)))
+ result = gate_act * up_val.to(tl.float32)
+
+ valid = m < expert_valid_m
+ result = tl.where(valid, result, 0.0)
+
+ tl.store(
+ gate_base + m * stride_gm + n * stride_gn,
+ result.to(gate_up_ptr.dtype.element_ty),
+ mask=mask,
+ )
+
+
+def fused_act_mul_masked_inplace(
+ gate_up: torch.Tensor,
+ intermediate_size: int,
+ masked_m: torch.Tensor,
+ use_gelu: bool = False,
+) -> None:
+ """Fused activation + multiply + masking, written in-place to gate_up[:,:,:I].
+
+ After this call, gate_up[:, :, :intermediate_size] contains the masked
+ activated intermediate, suitable for the down projection GEMM.
+
+ Args:
+ gate_up: [E, M, 2*I] output of bmm(tokens, w13.T), modified in-place
+ intermediate_size: I
+ masked_m: [E] per-expert valid token count
+ use_gelu: use tanh-approximated GELU instead of SiLU
+ """
+ E, M, _ = gate_up.shape
+ N = intermediate_size
+
+ total = M * N
+ BLOCK = 1024
+ grid = (triton.cdiv(total, BLOCK), E)
+
+ kernel = _gelu_mul_masked_kernel if use_gelu else _silu_mul_masked_kernel
+ kernel[grid](
+ gate_up,
+ masked_m,
+ M,
+ N,
+ gate_up.stride(0),
+ gate_up.stride(1),
+ gate_up.stride(2),
+ BLOCK=BLOCK,
+ )
@@ -132,11 +132,12 @@ class DeepEPMoE(FusedMoE):
and not _is_npu
and not (
get_moe_runner_backend().is_flashinfer_cutedsl()
+ and self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
+ and (self.use_fp8_w8a8 or self.use_w4afp8)
):
- # NPU supports low_latency deepep without deepgemm
- # FP4 quantization with flashinfer_cutedsl also supports low_latency deepep without deepgemm
+ # BF16 models don't need deep_gemm; they run fused_experts (normal mode) or torch.bmm (low-latency mode)
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -154,6 +155,10 @@ class DeepEPMoE(FusedMoE):
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
+ # Set bf16_weights flag on dispatcher so dispatch skips FP8 quantization
+ if not self.use_fp8_w8a8 and not self.use_w4afp8:
+ self.dispatcher.set_quant_config({"bf16_weights": True})
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -228,6 +233,8 @@ class DeepEPMoE(FusedMoE):
elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
output = self.forward_cutlass_w4afp8(dispatch_output)
+ elif not self.use_fp8_w8a8:
+ output = self.forward_bf16_normal(dispatch_output)
else:
assert False, "forward_deepgemm_contiguous is deprecated"
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
@@ -238,6 +245,8 @@ class DeepEPMoE(FusedMoE):
output = self.forward_flashinfer_cutedsl(dispatch_output)
elif self.use_w4afp8:
output = self.forward_cutlass_w4afp8_masked(dispatch_output)
+ elif not self.use_fp8_w8a8:
+ output = self.forward_bf16_ll(dispatch_output)
else:
assert False, "forward_deepgemm_masked is deprecated"
@@ -341,6 +350,71 @@ class DeepEPMoE(FusedMoE):
dispatch_output=dispatch_output,
)
+ def forward_bf16_normal(
+ self,
+ dispatch_output: DeepEPNormalDispatchOutput,
+ ) -> torch.Tensor:
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+ hidden_states = dispatch_output.hidden_states
+ topk_ids = dispatch_output.topk_ids
+ topk_weights = dispatch_output.topk_weights
+
+ if hidden_states.shape[0] == 0:
+ return hidden_states
+
+ # topk_ids uses local expert IDs (0..num_local_experts-1), -1 for remote.
+ # fused_experts handles -1 via moe_align_block_size filtering.
+ return fused_experts(
+ hidden_states=hidden_states,
+ w1=self.w13_weight,
+ w2=self.w2_weight,
+ topk_output=(topk_weights, topk_ids, None),
+ moe_runner_config=self.moe_runner_config,
+ )
+
+ def forward_bf16_ll(
+ self,
+ dispatch_output: DeepEPLLDispatchOutput,
+ ) -> torch.Tensor:
+ from sglang.srt.layers.moe.ep_moe.deepep_bf16_kernels import (
+ fused_act_mul_masked_inplace,
+ )
+
+ hidden_states = dispatch_output.hidden_states
+ masked_m = dispatch_output.masked_m
+ expected_m = dispatch_output.expected_m
+
+ _, max_tokens, _ = hidden_states.shape
+ if masked_m.numel() == 0 or max_tokens == 0:
+ return hidden_states
+
+ expected_m = min(expected_m, max_tokens)
+ if expected_m <= 0:
+ return hidden_states
+
+ tokens = hidden_states[:, :expected_m, :]
+
+ # 1. Gate+Up GEMM (cuBLAS batched GEMM)
+ gate_up = torch.bmm(tokens, self.w13_weight.transpose(1, 2))
+
+ # 2. Fused act(gate)*up + masking in-place (SiLU or GELU, one Triton kernel)
+ fused_act_mul_masked_inplace(
+ gate_up,
+ self.intermediate_size_per_partition,
+ masked_m,
+ use_gelu=(self.moe_runner_config.activation == "gelu"),
+ )
+
+ # 3. Down GEMM into hidden_states (cuBLAS, non-contiguous input is OK)
+ torch.bmm(
+ gate_up[:, :, : self.intermediate_size_per_partition],
+ self.w2_weight.transpose(1, 2),
+ out=hidden_states[:, :expected_m, :],
+ )
+
+ return hidden_states
+
def forward_npu(
self,
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
@@ -697,6 +697,7 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16TritonMoE",
]
)
+ and "zero" not in weight_name
else loaded_weight
)
@@ -821,13 +822,16 @@ class FusedMoE(torch.nn.Module):
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
- self._load_model_weight_or_group_weight_scale(
- shard_id=shard_id,
- shard_dim=shard_dim,
- loaded_weight=loaded_weight,
- expert_data=expert_data,
- tp_rank=tp_rank,
- )
+ if getattr(param, "load_full_w2", False) and shard_id == "w2":
+ expert_data.copy_(loaded_weight)
+ else:
+ self._load_model_weight_or_group_weight_scale(
+ shard_id=shard_id,
+ shard_dim=shard_dim,
+ loaded_weight=loaded_weight,
+ expert_data=expert_data,
+ tp_rank=tp_rank,
+ )
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
@@ -916,6 +920,7 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16TritonMoE",
]
)
+ and "zero" not in weight_name
else loaded_weight
)
@@ -8,10 +8,15 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import (
+ attn_tp_all_gather_into_tensor,
get_attention_dp_rank,
+ get_attention_tp_size,
get_dp_local_info,
is_dp_attention_enabled,
)
+from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
@@ -181,13 +186,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
device=device,
)
+ if get_moe_a2a_backend().is_deepep():
+ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1
+ self.gather_buffer = torch.empty(
+ (
+ self.device_cache.buffer.shape[0] * attn_tp_size,
+ self.device_cache.buffer.shape[2],
+ ),
+ dtype=torch.int32,
+ device=device,
+ )
+
def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
- if is_dp_attention_enabled():
+ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer
+ # contains data from all DP ranks. We should not slice by DP rank in this case.
+ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# handle with cuda graph padding
if can_run_graph:
@@ -206,6 +224,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
].cpu()
def capture(self, layer_id: int, topk_ids: torch.Tensor):
+ if get_moe_a2a_backend().is_deepep():
+ local_topk_ids = topk_ids
+ topk_ids = self.gather_buffer[
+ : local_topk_ids.size(0) * get_attention_tp_size()
+ ]
+ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids)
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
def get_routed_experts(
@@ -388,6 +388,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and not get_moe_runner_backend().is_cutlass()
and not envs.SGLANG_DEEPEP_BF16_DISPATCH.get()
+ and not self.quant_config.get("bf16_weights", False)
):
# TODO hard code 128 block quant,use fp8 communication
hidden_states = sglang_per_token_group_quant_fp8(
@@ -466,7 +467,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
previous_event=previous_event,
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
- expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
+ expert_alignment=(
+ 128
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and not self.quant_config.get("bf16_weights", False)
+ else 1
+ ),
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
@@ -491,7 +497,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_weights: torch.Tensor,
):
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
+ if (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ or _use_aiter
+ or _is_npu
+ or self.quant_config.get("bf16_weights", False)
+ ):
output = hidden_states
else:
raise NotImplementedError() # triton runner was supported but it's temporarily disabled
@@ -551,10 +562,18 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
buffer = self._get_buffer()
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
topk_ids = topk_ids.to(torch.int64)
- expected_m = (
- hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
- + self.num_experts
- ) // self.num_experts
+ if self.quant_config.get("bf16_weights", False):
+ # BF16 low-latency path slices hidden_states[:, :expected_m, :], so
+ # expected_m must remain a correctness-preserving upper bound.
+ expected_m = min(
+ hidden_states.shape[0] * buffer.group_size,
+ self.num_max_dispatch_tokens_per_rank * buffer.group_size,
+ )
+ else:
+ expected_m = (
+ hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
+ + self.num_experts
+ ) // self.num_experts
hidden_states, masked_m, event, hook = self._dispatch_core(
hidden_states,
topk_ids,
@@ -609,7 +628,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
input_global_scale = self.quant_config.get("input_global_scale", None)
if input_global_scale is not None:
use_nvfp4 = True
- elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get():
+ elif not envs.SGLANG_DEEPEP_BF16_DISPATCH.get() and not self.quant_config.get(
+ "bf16_weights", False
+ ):
use_fp8 = True
buffer = self._get_buffer()
@@ -499,7 +499,7 @@ class CompressedTensorsConfig(QuantizationConfig):
)
is_static = not weight_quant.dynamic
- return is_channel_group and input_quant_none and is_symmetric and is_static
+ return is_channel_group and input_quant_none and is_static
def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
@@ -968,6 +968,9 @@ class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
+ def restore_weights_before_loading(self, layer: torch.nn.Module) -> None:
+ layer.scheme.restore_weights_before_loading(layer)
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -17,7 +17,10 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsMoEScheme,
)
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
-from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
+from sglang.srt.layers.quantization.marlin_utils import (
+ marlin_moe_permute_scales,
+ moe_awq_to_marlin_zero_points,
+)
from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs
@@ -64,7 +67,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
- assert config.symmetric, "Only symmetric quantization is supported for MoE"
+ self.sym = config.symmetric
if not (
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
@@ -124,7 +127,7 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
- load_full_w2 = self.actorder and self.group_size != -1
+ load_full_w2 = (self.actorder != "static") and self.group_size != -1
if load_full_w2:
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
@@ -172,6 +175,32 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+ # Zero-point parameters, only needed for asymmetric quantization
+ if not self.sym:
+ w13_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w13,
+ 2 * intermediate_size_per_partition // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight_zero_point", w13_qzeros)
+ set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+ w2_qzeros = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ num_groups_w2,
+ hidden_size // self.packed_factor,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_weight_zero_point", w2_qzeros)
+ set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
@@ -225,11 +254,14 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
# Force record: these are the target GPTQ shapes for rollback.
layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
- layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
+ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+ if not self.sym:
+ layer._original_shapes["w13_weight_zero_point"] = w13_qzeros.shape
- # Also record the shapes of the scales.
+ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
- layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+ if not self.sym:
+ layer._original_shapes["w2_weight_zero_point"] = tuple(w2_qzeros.shape)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -334,6 +366,24 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
)
replace_tensor("w2_weight_scale", marlin_w2_scales)
+ # Repack zero points into the Marlin layout (asymmetric quantization only)
+ if not self.sym:
+ marlin_w13_zp = moe_awq_to_marlin_zero_points(
+ layer.w13_weight_zero_point,
+ size_k=layer.w13_weight_zero_point.shape[1],
+ size_n=layer.w13_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w13_weight_zero_point", marlin_w13_zp)
+
+ marlin_w2_zp = moe_awq_to_marlin_zero_points(
+ layer.w2_weight_zero_point,
+ size_k=layer.w2_weight_zero_point.shape[1],
+ size_n=layer.w2_weight_zero_point.shape[2] * self.packed_factor,
+ num_bits=self.num_bits,
+ )
+ replace_tensor("w2_weight_zero_point", marlin_w2_zp)
+
layer.is_marlin_converted = True
def restore_weights_before_loading(self, layer: torch.nn.Module):
@@ -399,6 +449,8 @@ class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme):
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
+ w1_zeros=layer.w13_weight_zero_point if not self.sym else None,
+ w2_zeros=layer.w2_weight_zero_point if not self.sym else None,
num_bits=self.num_bits,
is_k_full=self.is_k_full,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
@@ -305,9 +305,6 @@ class RotaryEmbedding(MultiPlatformOp):
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-npu implementation of forward()."""
- assert (
- fused_set_kv_buffer_arg is None
- ), "fused_set_kv_buffer_arg is not supported for npu implementation"
if query.dtype == torch.bfloat16 and self.cos_sin_cache.dtype == torch.float:
return self.forward_native(positions, query, key, offsets)
if self.is_neox_style:
@@ -1778,6 +1775,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert (
+ fused_set_kv_buffer_arg is None
+ ), "fused_set_kv_buffer_arg is not supported for npu implementation"
# TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096
assert (
fused_set_kv_buffer_arg is None
@@ -405,6 +405,17 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
prefill_launch_delay=recv_obj.prefill_launch_delay,
prefill_launch_latency=recv_obj.prefill_launch_latency,
prefill_finished_ts=recv_obj.prefill_finished_ts,
+ pd_prefill_bootstrap_queue_duration=recv_obj.pd_prefill_bootstrap_queue_duration,
+ pd_prefill_forward_duration=recv_obj.pd_prefill_forward_duration,
+ pd_prefill_transfer_queue_duration=recv_obj.pd_prefill_transfer_queue_duration,
+ pd_decode_prealloc_duration=recv_obj.pd_decode_prealloc_duration,
+ pd_decode_transfer_duration=recv_obj.pd_decode_transfer_duration,
+ pd_decode_forward_duration=recv_obj.pd_decode_forward_duration,
+ pd_bootstrap_duration=recv_obj.pd_bootstrap_duration,
+ pd_alloc_waiting_duration=recv_obj.pd_alloc_waiting_duration,
+ pd_transfer_speed_gb_s=recv_obj.pd_transfer_speed_gb_s,
+ pd_transfer_total_mb=recv_obj.pd_transfer_total_mb,
+ pd_prefill_retry_count=recv_obj.pd_prefill_retry_count,
)
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -101,6 +101,42 @@ class RequestTimingMetricsMixin:
# This marks when the prefill computation finishes.
prefill_finished_ts: Optional[List[Optional[float]]]
+ # --- PD disaggregation timing fields ---
+ # All fields are None when profiling is disabled or not in PD disaggregation mode.
+
+ # P instance: duration spent in bootstrap queue before entering the wait queue.
+ pd_prefill_bootstrap_queue_duration: Optional[List[Optional[float]]]
+
+ # P instance: duration for the actual prefill forward computation.
+ pd_prefill_forward_duration: Optional[List[Optional[float]]]
+
+ # P instance: duration spent in the KV transfer queue.
+ pd_prefill_transfer_queue_duration: Optional[List[Optional[float]]]
+
+ # D instance: duration waiting for KV cache slot pre-allocation.
+ pd_decode_prealloc_duration: Optional[List[Optional[float]]]
+
+ # D instance: duration waiting for the KV cache transfer to complete.
+ pd_decode_transfer_duration: Optional[List[Optional[float]]]
+
+ # D instance: duration for the actual decode forward computation.
+ pd_decode_forward_duration: Optional[List[Optional[float]]]
+
+ # Bootstrap handshake duration (P and D instances).
+ pd_bootstrap_duration: Optional[List[Optional[float]]]
+
+ # KV cache allocation waiting duration (P and D instances).
+ pd_alloc_waiting_duration: Optional[List[Optional[float]]]
+
+ # KV cache transfer speed in GB/s.
+ pd_transfer_speed_gb_s: Optional[List[Optional[float]]]
+
+ # Total KV cache transferred in MB.
+ pd_transfer_total_mb: Optional[List[Optional[float]]]
+
+ # Number of prefill retries (P instance only).
+ pd_prefill_retry_count: Optional[List[Optional[int]]]
+
@dataclass
class SpeculativeDecodingMetricsMixin:
@@ -1403,6 +1439,20 @@ class UpdateWeightsFromIPCReqOutput(BaseReq):
message: str
+@dataclass
+class PostProcessWeightsReqInput(BaseReq):
+ # Whether to restore weights before loading new weights
+ restore_weights_before_load: bool = False
+ # Whether to enable quantization post-processing
+ post_process_quantization: bool = False
+
+
+@dataclass
+class PostProcessWeightsReqOutput(BaseReq):
+ success: bool
+ message: str
+
+
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
success: bool
@@ -1802,6 +1852,10 @@ class GetLoadReqOutput(BaseReq):
num_waiting_reqs: int
num_tokens: int
ts_tic: float
+ # Per-queue breakdown: list of {name, num_reqs, num_tokens, reqs: [{seqlen}]}
+ queue_details: Optional[List[Dict[str, Any]]] = None
+ # Running batch info
+ running_details: Optional[Dict[str, Any]] = None
@dataclass
@@ -142,6 +142,39 @@ def _handle_output_by_index(output, i):
prefill_finished_ts=_extract_field_by_index(
output, "prefill_finished_ts", i
),
+ pd_prefill_bootstrap_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_bootstrap_queue_duration", i
+ ),
+ pd_prefill_forward_duration=_extract_field_by_index(
+ output, "pd_prefill_forward_duration", i
+ ),
+ pd_prefill_transfer_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_transfer_queue_duration", i
+ ),
+ pd_decode_prealloc_duration=_extract_field_by_index(
+ output, "pd_decode_prealloc_duration", i
+ ),
+ pd_decode_transfer_duration=_extract_field_by_index(
+ output, "pd_decode_transfer_duration", i
+ ),
+ pd_decode_forward_duration=_extract_field_by_index(
+ output, "pd_decode_forward_duration", i
+ ),
+ pd_bootstrap_duration=_extract_field_by_index(
+ output, "pd_bootstrap_duration", i
+ ),
+ pd_alloc_waiting_duration=_extract_field_by_index(
+ output, "pd_alloc_waiting_duration", i
+ ),
+ pd_transfer_speed_gb_s=_extract_field_by_index(
+ output, "pd_transfer_speed_gb_s", i
+ ),
+ pd_transfer_total_mb=_extract_field_by_index(
+ output, "pd_transfer_total_mb", i
+ ),
+ pd_prefill_retry_count=_extract_field_by_index(
+ output, "pd_prefill_retry_count", i
+ ),
finished_reasons=_extract_field_by_index(output, "finished_reasons", i),
decoded_texts=_extract_field_by_index(output, "decoded_texts", i),
decode_ids=_extract_field_by_index(output, "decode_ids", i),
@@ -211,6 +244,50 @@ def _handle_output_by_index(output, i):
elif isinstance(output, BatchEmbeddingOutput):
new_output = BatchEmbeddingOutput(
rids=[output.rids[i]],
+ queue_time=_extract_field_by_index(output, "queue_time", i),
+ forward_entry_time=_extract_field_by_index(output, "forward_entry_time", i),
+ prefill_launch_delay=_extract_field_by_index(
+ output, "prefill_launch_delay", i
+ ),
+ prefill_launch_latency=_extract_field_by_index(
+ output, "prefill_launch_latency", i
+ ),
+ prefill_finished_ts=_extract_field_by_index(
+ output, "prefill_finished_ts", i
+ ),
+ pd_prefill_bootstrap_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_bootstrap_queue_duration", i
+ ),
+ pd_prefill_forward_duration=_extract_field_by_index(
+ output, "pd_prefill_forward_duration", i
+ ),
+ pd_prefill_transfer_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_transfer_queue_duration", i
+ ),
+ pd_decode_prealloc_duration=_extract_field_by_index(
+ output, "pd_decode_prealloc_duration", i
+ ),
+ pd_decode_transfer_duration=_extract_field_by_index(
+ output, "pd_decode_transfer_duration", i
+ ),
+ pd_decode_forward_duration=_extract_field_by_index(
+ output, "pd_decode_forward_duration", i
+ ),
+ pd_bootstrap_duration=_extract_field_by_index(
+ output, "pd_bootstrap_duration", i
+ ),
+ pd_alloc_waiting_duration=_extract_field_by_index(
+ output, "pd_alloc_waiting_duration", i
+ ),
+ pd_transfer_speed_gb_s=_extract_field_by_index(
+ output, "pd_transfer_speed_gb_s", i
+ ),
+ pd_transfer_total_mb=_extract_field_by_index(
+ output, "pd_transfer_total_mb", i
+ ),
+ pd_prefill_retry_count=_extract_field_by_index(
+ output, "pd_prefill_retry_count", i
+ ),
finished_reasons=_extract_field_by_index(output, "finished_reasons", i),
embeddings=_extract_field_by_index(output, "embeddings", i),
prompt_tokens=_extract_field_by_index(output, "prompt_tokens", i),
@@ -239,6 +316,39 @@ def _handle_output_by_index(output, i):
prefill_finished_ts=_extract_field_by_index(
output, "prefill_finished_ts", i
),
+ pd_prefill_bootstrap_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_bootstrap_queue_duration", i
+ ),
+ pd_prefill_forward_duration=_extract_field_by_index(
+ output, "pd_prefill_forward_duration", i
+ ),
+ pd_prefill_transfer_queue_duration=_extract_field_by_index(
+ output, "pd_prefill_transfer_queue_duration", i
+ ),
+ pd_decode_prealloc_duration=_extract_field_by_index(
+ output, "pd_decode_prealloc_duration", i
+ ),
+ pd_decode_transfer_duration=_extract_field_by_index(
+ output, "pd_decode_transfer_duration", i
+ ),
+ pd_decode_forward_duration=_extract_field_by_index(
+ output, "pd_decode_forward_duration", i
+ ),
+ pd_bootstrap_duration=_extract_field_by_index(
+ output, "pd_bootstrap_duration", i
+ ),
+ pd_alloc_waiting_duration=_extract_field_by_index(
+ output, "pd_alloc_waiting_duration", i
+ ),
+ pd_transfer_speed_gb_s=_extract_field_by_index(
+ output, "pd_transfer_speed_gb_s", i
+ ),
+ pd_transfer_total_mb=_extract_field_by_index(
+ output, "pd_transfer_total_mb", i
+ ),
+ pd_prefill_retry_count=_extract_field_by_index(
+ output, "pd_prefill_retry_count", i
+ ),
finished_reasons=_extract_field_by_index(output, "finished_reasons", i),
output_strs=_extract_field_by_index(output, "output_strs", i),
output_ids=_extract_field_by_index(output, "output_ids", i),
@@ -524,6 +634,60 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
"uvicorn.supervisors.multiprocess not found, skipping monkey patch"
)
+ # Fix stdin fd issue when running under Ray (or other managed
+ # environments where stdin may not be a real terminal):
+ #
+ # Uvicorn's get_subprocess() captures sys.stdin.fileno() in the parent
+ # and passes it to spawn'd children, which call os.fdopen(stdin_fileno)
+ # to re-attach stdin. This is intended for interactive debugging (e.g.
+ # pdb attach to a child worker).
+ #
+ # In Ray Actors, sys.stdin.fileno() succeeds in the parent (returns a
+ # valid fd number), but the fd is not inheritable across spawn. The
+ # child's os.fdopen() then crashes with OSError: [Errno 9] Bad file
+ # descriptor, killing every tokenizer worker.
+ #
+ # Instead of unconditionally disabling stdin passthrough, we probe
+ # whether the fd is truly usable by dup'ing it. If os.dup() fails,
+ # the fd won't survive spawn either, so we fall back to None. In a
+ # normal terminal environment os.dup() succeeds and debugging ability
+ # is preserved.
+ try:
+ import uvicorn._subprocess as _uv_sub
+ import uvicorn.supervisors.multiprocess as _uv_mp
+
+ def _safe_get_stdin_fileno():
+ """Return stdin fileno only if it is genuinely usable."""
+ try:
+ fileno = sys.stdin.fileno()
+ # Verify the fd is valid and duplicable — if it isn't,
+ # spawn'd children won't be able to reopen it either.
+ dup_fd = os.dup(fileno)
+ os.close(dup_fd)
+ return fileno
+ except (AttributeError, OSError):
+ return None
+
+ def _patched_get_subprocess(config, target, sockets):
+ stdin_fileno = _safe_get_stdin_fileno()
+ kwargs = {
+ "config": config,
+ "target": target,
+ "sockets": sockets,
+ "stdin_fileno": stdin_fileno,
+ }
+ return _uv_sub.spawn.Process(
+ target=_uv_sub.subprocess_started, kwargs=kwargs
+ )
+
+ # Must patch both: the supervisor module caches its own reference
+ # to get_subprocess at import time via
+ # ``from uvicorn._subprocess import get_subprocess``.
+ _uv_sub.get_subprocess = _patched_get_subprocess
+ _uv_mp.get_subprocess = _patched_get_subprocess
+ except Exception:
+ pass
+
class SenderWrapper:
def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
@@ -1869,7 +1869,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
while first_iter or (
not self.check_decode_mem(selected_indices=sorted_indices)
):
- if len(sorted_indices) == 1:
+ # We should allow all requests to be retracted in decode disaggregation mode
+ # because there can be prealloc prefill requests.
+ num_minimum_reqs = 0 if server_args.disaggregation_mode == "decode" else 1
+ if len(sorted_indices) == num_minimum_reqs:
# Always keep at least one request
break
@@ -114,6 +114,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -952,6 +953,11 @@ class Scheduler(
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
+ # When CP > 1, all CP ranks must agree on poll results so they
+ # enter run_batch together; use the full TP gloo group for consensus.
+ disagg_prefill_gloo_group = (
+ self.tp_cpu_group if self.attn_cp_size > 1 else self.attn_tp_cpu_group
+ )
self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=draft_token_to_kv_pool,
@@ -961,7 +967,7 @@ class Scheduler(
tp_size=self.tp_size,
gpu_id=self.gpu_id,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
- gloo_group=self.attn_tp_cpu_group,
+ gloo_group=disagg_prefill_gloo_group,
max_total_num_tokens=self.max_total_num_tokens,
decode_tp_size=self.server_args.disaggregation_decode_tp,
decode_dp_size=self.server_args.disaggregation_decode_dp,
@@ -1063,6 +1069,7 @@ class Scheduler(
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
+ (PostProcessWeightsReqInput, self.post_process_weights),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
@@ -609,12 +609,54 @@ class SchedulerMetricsMixin:
num_tokens += sum(req.seqlen for queue in waiting_queues for req in queue)
num_waiting_reqs = sum(len(queue) for queue in waiting_queues)
+ # Collect per-queue details
+ queue_names = ["waiting_queue"]
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
+ queue_names.append("bootstrap_queue")
+ elif self.disaggregation_mode == DisaggregationMode.DECODE:
+ queue_names.append("prealloc_queue")
+ queue_names.append("transfer_queue")
+ queue_names.append("retracted_queue")
+
+ queue_details = []
+ for name, queue in zip(queue_names, waiting_queues):
+ reqs_info = []
+ for req in queue:
+ reqs_info.append(
+ {
+ "seqlen": req.seqlen,
+ }
+ )
+ queue_details.append(
+ {
+ "name": name,
+ "num_reqs": len(queue),
+ "num_tokens": sum(r["seqlen"] for r in reqs_info),
+ "reqs": reqs_info,
+ }
+ )
+
+ # Collect running batch details
+ running_reqs_info = []
+ for req in self.running_batch.reqs:
+ running_reqs_info.append(
+ {
+ "seqlen": req.seqlen,
+ }
+ )
+ running_details = {
+ "num_reqs": len(self.running_batch.reqs),
+ "reqs": running_reqs_info,
+ }
+
return GetLoadReqOutput(
dp_rank=self.dp_rank,
num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
num_waiting_reqs=num_waiting_reqs,
num_tokens=num_tokens,
ts_tic=time.perf_counter(),
+ queue_details=queue_details,
+ running_details=running_details,
)
def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutput:
@@ -922,6 +922,18 @@ class SchedulerOutputProcessorMixin:
prefill_launch_delays = []
prefill_launch_latencies = []
prefill_finished_timestamps = []
+ profiling_enabled = envs.SLIME_ENABLE_PROFILING.get()
+ pd_prefill_bootstrap_queue_durations = [] if profiling_enabled else None
+ pd_prefill_forward_durations = [] if profiling_enabled else None
+ pd_prefill_transfer_queue_durations = [] if profiling_enabled else None
+ pd_decode_prealloc_durations = [] if profiling_enabled else None
+ pd_decode_transfer_durations = [] if profiling_enabled else None
+ pd_decode_forward_durations = [] if profiling_enabled else None
+ pd_bootstrap_durations = [] if profiling_enabled else None
+ pd_alloc_waiting_durations = [] if profiling_enabled else None
+ pd_transfer_speeds_gb_s = [] if profiling_enabled else None
+ pd_transfer_totals_mb = [] if profiling_enabled else None
+ pd_prefill_retry_counts = [] if profiling_enabled else None
if return_logprob:
input_token_logprobs_val = []
@@ -1037,6 +1049,40 @@ class SchedulerOutputProcessorMixin:
prefill_finished_timestamps.append(
req.time_stats.get_prefill_finished_ts()
)
+ if profiling_enabled:
+ pd_prefill_bootstrap_queue_durations.append(
+ req.time_stats.get_pd_prefill_bootstrap_queue_duration()
+ )
+ pd_prefill_forward_durations.append(
+ req.time_stats.get_pd_prefill_forward_duration()
+ )
+ pd_prefill_transfer_queue_durations.append(
+ req.time_stats.get_pd_prefill_transfer_queue_duration()
+ )
+ pd_decode_prealloc_durations.append(
+ req.time_stats.get_pd_decode_prealloc_duration()
+ )
+ pd_decode_transfer_durations.append(
+ req.time_stats.get_pd_decode_transfer_duration()
+ )
+ pd_decode_forward_durations.append(
+ req.time_stats.get_pd_decode_forward_duration()
+ )
+ pd_bootstrap_durations.append(
+ req.time_stats.get_pd_bootstrap_duration()
+ )
+ pd_alloc_waiting_durations.append(
+ req.time_stats.get_pd_alloc_waiting_duration()
+ )
+ pd_transfer_speeds_gb_s.append(
+ req.time_stats.get_pd_transfer_speed_gb_s()
+ )
+ pd_transfer_totals_mb.append(
+ req.time_stats.get_pd_transfer_total_mb()
+ )
+ pd_prefill_retry_counts.append(
+ req.time_stats.get_pd_prefill_retry_count()
+ )
if not self.spec_algorithm.is_none():
spec_verify_ct.append(req.spec_verify_ct)
@@ -1134,7 +1180,7 @@ class SchedulerOutputProcessorMixin:
req.log_time_stats()
# Send to detokenizer
- if reqs or is_idle_batch:
+ if rids or is_idle_batch:
if self.model_config.is_multimodal_gen:
return
self.send_to_detokenizer.send_output(
@@ -1149,6 +1195,17 @@ class SchedulerOutputProcessorMixin:
prefill_launch_delay=prefill_launch_delays,
prefill_launch_latency=prefill_launch_latencies,
prefill_finished_ts=prefill_finished_timestamps,
+ pd_prefill_bootstrap_queue_duration=pd_prefill_bootstrap_queue_durations,
+ pd_prefill_forward_duration=pd_prefill_forward_durations,
+ pd_prefill_transfer_queue_duration=pd_prefill_transfer_queue_durations,
+ pd_decode_prealloc_duration=pd_decode_prealloc_durations,
+ pd_decode_transfer_duration=pd_decode_transfer_durations,
+ pd_decode_forward_duration=pd_decode_forward_durations,
+ pd_bootstrap_duration=pd_bootstrap_durations,
+ pd_alloc_waiting_duration=pd_alloc_waiting_durations,
+ pd_transfer_speed_gb_s=pd_transfer_speeds_gb_s,
+ pd_transfer_total_mb=pd_transfer_totals_mb,
+ pd_prefill_retry_count=pd_prefill_retry_counts,
finished_reasons=finished_reasons,
decoded_texts=decoded_texts,
decode_ids=decode_ids_list,
@@ -1198,6 +1255,18 @@ class SchedulerOutputProcessorMixin:
prefill_launch_delays = []
prefill_launch_latencies = []
prefill_finished_timestamps = []
+ profiling_enabled = envs.SLIME_ENABLE_PROFILING.get()
+ pd_prefill_bootstrap_queue_durations = [] if profiling_enabled else None
+ pd_prefill_forward_durations = [] if profiling_enabled else None
+ pd_prefill_transfer_queue_durations = [] if profiling_enabled else None
+ pd_decode_prealloc_durations = [] if profiling_enabled else None
+ pd_decode_transfer_durations = [] if profiling_enabled else None
+ pd_decode_forward_durations = [] if profiling_enabled else None
+ pd_bootstrap_durations = [] if profiling_enabled else None
+ pd_alloc_waiting_durations = [] if profiling_enabled else None
+ pd_transfer_speeds_gb_s = [] if profiling_enabled else None
+ pd_transfer_totals_mb = [] if profiling_enabled else None
+ pd_prefill_retry_counts = [] if profiling_enabled else None
retraction_counts = []
for req in reqs:
if req.finished():
@@ -1221,6 +1290,40 @@ class SchedulerOutputProcessorMixin:
prefill_finished_timestamps.append(
req.time_stats.get_prefill_finished_ts()
)
+ if profiling_enabled:
+ pd_prefill_bootstrap_queue_durations.append(
+ req.time_stats.get_pd_prefill_bootstrap_queue_duration()
+ )
+ pd_prefill_forward_durations.append(
+ req.time_stats.get_pd_prefill_forward_duration()
+ )
+ pd_prefill_transfer_queue_durations.append(
+ req.time_stats.get_pd_prefill_transfer_queue_duration()
+ )
+ pd_decode_prealloc_durations.append(
+ req.time_stats.get_pd_decode_prealloc_duration()
+ )
+ pd_decode_transfer_durations.append(
+ req.time_stats.get_pd_decode_transfer_duration()
+ )
+ pd_decode_forward_durations.append(
+ req.time_stats.get_pd_decode_forward_duration()
+ )
+ pd_bootstrap_durations.append(
+ req.time_stats.get_pd_bootstrap_duration()
+ )
+ pd_alloc_waiting_durations.append(
+ req.time_stats.get_pd_alloc_waiting_duration()
+ )
+ pd_transfer_speeds_gb_s.append(
+ req.time_stats.get_pd_transfer_speed_gb_s()
+ )
+ pd_transfer_totals_mb.append(
+ req.time_stats.get_pd_transfer_total_mb()
+ )
+ pd_prefill_retry_counts.append(
+ req.time_stats.get_pd_prefill_retry_count()
+ )
retraction_counts.append(req.retraction_count)
self.send_to_detokenizer.send_output(
BatchEmbeddingOutput(
@@ -1231,6 +1334,17 @@ class SchedulerOutputProcessorMixin:
prefill_launch_delay=prefill_launch_delays,
prefill_launch_latency=prefill_launch_latencies,
prefill_finished_ts=prefill_finished_timestamps,
+ pd_prefill_bootstrap_queue_duration=pd_prefill_bootstrap_queue_durations,
+ pd_prefill_forward_duration=pd_prefill_forward_durations,
+ pd_prefill_transfer_queue_duration=pd_prefill_transfer_queue_durations,
+ pd_decode_prealloc_duration=pd_decode_prealloc_durations,
+ pd_decode_transfer_duration=pd_decode_transfer_durations,
+ pd_decode_forward_duration=pd_decode_forward_durations,
+ pd_bootstrap_duration=pd_bootstrap_durations,
+ pd_alloc_waiting_duration=pd_alloc_waiting_durations,
+ pd_transfer_speed_gb_s=pd_transfer_speeds_gb_s,
+ pd_transfer_total_mb=pd_transfer_totals_mb,
+ pd_prefill_retry_count=pd_prefill_retry_counts,
finished_reasons=finished_reasons,
embeddings=embeddings,
prompt_tokens=prompt_tokens,
@@ -20,6 +20,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
is_dp_attention_enabled,
+ set_is_extend_in_batch,
)
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.utils import (
@@ -224,7 +225,28 @@ class SchedulerPPMixin:
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
- batch = self.maybe_prepare_mlp_sync_batch(batch)
+ need_mlp_sync = self.require_mlp_sync
+ skipped_mlp_sync = False
+ if (
+ need_mlp_sync
+ and self.disaggregation_mode == DisaggregationMode.PREFILL
+ and self.server_args.enable_nsa_prefill_context_parallel
+ and self.pp_size > 1
+ ):
+ # In PD prefill CP+PP, MLP sync all_gather can deadlock on idle micro-batches.
+ # Skip MLP sync here because decode-side MLP gather is not involved in this path.
+ need_mlp_sync = False
+ skipped_mlp_sync = True
+ batch = self.maybe_prepare_mlp_sync_batch(
+ batch, need_sync=need_mlp_sync
+ )
+ if skipped_mlp_sync:
+ # MLP sync was skipped but set_is_extend_in_batch is still needed
+ # by the deepep dispatcher (called in model forward).
+ is_extend = (
+ batch.forward_mode.is_extend() if batch is not None else False
+ )
+ set_is_extend_in_batch(is_extend)
self.mbs[mb_id] = batch
self.running_mbs[mb_id] = self.running_batch
@@ -288,6 +310,11 @@ class SchedulerPPMixin:
next_batch_result,
)
self.last_mbs[next_mb_id] = self.mbs[next_mb_id]
+ if self.current_scheduler_metrics_enabled:
+ self.log_prefill_stats(
+ prefill_stats=self.mbs[next_mb_id].prefill_stats,
+ can_run_cuda_graph=next_batch_result.can_run_cuda_graph,
+ )
if tmbs[next_mb_id] is not None:
self.process_disagg_prefill_inflight_queue(next_release_rids)
@@ -524,6 +551,11 @@ class SchedulerPPMixin:
self.last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = (
deque()
)
+ # PP1 (last rank) stores its own batch outputs locally to avoid the
+ # PP1→PP0→PP1 round-trip that causes a deadlock in disagg prefill.
+ self.last_rank_local_result_queue: deque[
+ Tuple[torch.cuda.Event, PPProxyTensors]
+ ] = deque()
self.send_req_work = []
self.send_proxy_work = []
@@ -859,31 +891,39 @@ class SchedulerPPMixin:
def _pp_send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False):
p2p_work = []
- if self.attn_tp_rank == 0:
- dp_offset = self.attn_dp_rank * self.attn_tp_size
+ if self.attn_tp_rank == 0 and self.attn_cp_rank == 0:
+ lane_offset = self.attn_dp_rank * self.attn_tp_size
p2p_work = point_to_point_pyobj(
data,
- self.pp_rank * self.tp_size + dp_offset,
+ self.pp_rank * self.tp_size + lane_offset,
self.world_group.cpu_group,
- self.pp_rank * self.tp_size + dp_offset,
- ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset,
+ self.pp_rank * self.tp_size + lane_offset,
+ ((self.pp_rank + 1) % self.pp_size) * self.tp_size + lane_offset,
async_send=async_send,
)
return p2p_work
def _pp_recv_pyobj_from_prev_stage(self: Scheduler):
- if self.attn_tp_rank == 0:
- dp_offset = self.attn_dp_rank * self.attn_tp_size
+ if self.attn_tp_rank == 0 and self.attn_cp_rank == 0:
+ lane_offset = self.attn_dp_rank * self.attn_tp_size
data = point_to_point_pyobj(
[],
- self.pp_rank * self.tp_size + dp_offset,
+ self.pp_rank * self.tp_size + lane_offset,
self.world_group.cpu_group,
- ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset,
- self.pp_rank * self.tp_size + dp_offset,
+ ((self.pp_rank - 1) % self.pp_size) * self.tp_size + lane_offset,
+ self.pp_rank * self.tp_size + lane_offset,
)
else:
data = None
+ if self.attn_cp_size > 1:
+ data = broadcast_pyobj(
+ data,
+ self.attn_cp_group.rank,
+ self.attn_cp_cpu_group,
+ src=self.attn_cp_group.ranks[0],
+ )
+
if self.attn_tp_size > 1:
data = broadcast_pyobj(
data,
@@ -1004,8 +1044,13 @@ class SchedulerPPMixin:
pp_outputs_to_send.tensors,
async_send=True,
)
- # send the outputs from the last round to let the next stage worker run post processing
- if not self.pp_group.is_last_rank:
+ # Store locally so the last rank can process its own batch result
+ # without receiving from the second-to-last rank (avoids deadlock).
+ self.last_rank_local_result_queue.append((q_event, pp_outputs_to_send))
+ elif self.pp_rank != self.pp_size - 2:
+ # Forward output through the chain: PP0→PP1→...→PP(last-2).
+ # The second-to-last rank does NOT forward to the last rank because
+ # the last rank uses last_rank_local_result_queue instead of receiving.
if pp_outputs:
with torch.profiler.record_function("send_res_dict_to_next_stage"):
send_output_work = self._pp_send_dict_to_next_stage(
@@ -1034,20 +1079,38 @@ class SchedulerPPMixin:
)
if mbs[next_mb_id] is not None:
- with torch.profiler.record_function("recv_res_dict_from_prev_stage"):
- next_pp_outputs = None
+ if self.pp_group.is_last_rank:
+ # Last rank: use the locally-stored output instead of receiving
+ # from the second-to-last rank. Receiving would cause a deadlock
+ # because the chain PP_last→PP0→...→PP(last-2)→PP_last requires
+ # PP0 to have pp_outputs ready, which it doesn't on the first batch.
if not mbs[next_mb_id].forward_mode.is_prebuilt():
- next_pp_outputs = PPProxyTensors(
- self._pp_recv_dict_from_prev_stage()
- )
- if not mbs[next_mb_id].forward_mode.is_prebuilt():
- with self.copy_stream_ctx:
- self.copy_stream.wait_stream(self.default_stream)
- batch_result = self._pp_prep_batch_result(
- mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
+ q_event, next_pp_outputs = (
+ self.last_rank_local_result_queue.popleft()
)
- d2h_event = torch.cuda.Event()
- d2h_event.record(torch.cuda.current_stream())
+ with self.copy_stream_ctx:
+ torch.cuda.current_stream().wait_event(q_event)
+ self.copy_stream.wait_stream(self.default_stream)
+ batch_result = self._pp_prep_batch_result(
+ mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
+ )
+ d2h_event = torch.cuda.Event()
+ d2h_event.record(torch.cuda.current_stream())
+ else:
+ with torch.profiler.record_function("recv_res_dict_from_prev_stage"):
+ next_pp_outputs = None
+ if not mbs[next_mb_id].forward_mode.is_prebuilt():
+ next_pp_outputs = PPProxyTensors(
+ self._pp_recv_dict_from_prev_stage()
+ )
+ if not mbs[next_mb_id].forward_mode.is_prebuilt():
+ with self.copy_stream_ctx:
+ self.copy_stream.wait_stream(self.default_stream)
+ batch_result = self._pp_prep_batch_result(
+ mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
+ )
+ d2h_event = torch.cuda.Event()
+ d2h_event.record(torch.cuda.current_stream())
return next_pp_outputs, batch_result, d2h_event, send_output_work
@@ -1085,9 +1148,12 @@ class SchedulerPPMixin:
"""
Used by PP, get the required rids with the given poll statuses.
"""
+ gloo_group = self.attn_tp_cpu_group
+ if self.attn_cp_size > 1:
+ gloo_group = self.tp_cpu_group
polls = poll_and_all_reduce(
[req.disagg_kv_sender if is_send else req.kv_receiver for req in req_queue],
- self.attn_tp_cpu_group,
+ gloo_group,
)
rids: List = []
for poll_statuses in poll_statuses_group:
@@ -347,7 +347,7 @@ class SchedulerProfilerMixin:
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
if self.profile_in_progress:
self.stop_profile(stage=ForwardMode.EXTEND)
- elif batch.forward_mode.is_decode():
+ elif batch.forward_mode.is_decode() or batch.forward_mode.is_prebuilt():
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
@@ -12,6 +12,7 @@ from sglang.srt.constants import (
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
+from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.io_struct import (
CheckWeightsReqInput,
CheckWeightsReqOutput,
@@ -21,6 +22,8 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
@@ -114,6 +117,11 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromIPCReqOutput(success, message)
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Optional post-processing for updated weights (e.g., Marlin conversion)."""
+ success, message = self.tp_worker.post_process_weights(recv_req)
+ return PostProcessWeightsReqOutput(success, message)
+
def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
@@ -137,6 +145,15 @@ class SchedulerUpdateWeightsMixin:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_transfer_queue"):
+ self.disagg_decode_transfer_queue.release_memory_occupation()
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.release_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.release_memory_occupation()
+
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.model_runner.model
@@ -177,6 +194,15 @@ class SchedulerUpdateWeightsMixin:
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_transfer_queue"):
+ self.disagg_decode_transfer_queue.resume_memory_occupation()
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.resume_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.resume_memory_occupation()
+
return ResumeMemoryOccupationReqOutput()
def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput,
LoRAUpdateOutput,
OpenSessionReqInput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
@@ -187,6 +189,9 @@ class TokenizerCommunicatorMixin:
self.update_weights_from_ipc_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.post_process_weights_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -272,6 +277,10 @@ class TokenizerCommunicatorMixin:
UpdateWeightsFromIPCReqOutput,
self.update_weights_from_ipc_communicator.handle_recv,
),
+ (
+ PostProcessWeightsReqOutput,
+ self.post_process_weights_communicator.handle_recv,
+ ),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
@@ -522,6 +531,17 @@ class TokenizerCommunicatorMixin:
return success, message
+ async def post_process_weights(
+ self: TokenizerManager,
+ obj: PostProcessWeightsReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion)."""
+ self.auto_create_handle_loop()
+ async with self.model_update_lock.writer_lock:
+ results = await self.post_process_weights_communicator(obj)
+ return _Communicator.merge_results(results)
+
async def init_weights_send_group_for_remote_instance(
self,
obj: InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -324,8 +324,12 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
if self.server_args.tokenizer_worker_num == 1:
+ self.send_to_scheduler_context = zmq.Context(1)
self.send_to_scheduler = get_zmq_socket(
- context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+ self.send_to_scheduler_context,
+ zmq.PUSH,
+ port_args.scheduler_input_ipc_name,
+ True,
)
else:
from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
@@ -1327,7 +1331,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
async with self.is_pause_cond:
self.is_pause = True
if obj.mode != "abort":
- await self.send_to_scheduler.send_pyobj(obj)
+ self.send_to_scheduler.send_pyobj(obj)
else:
# we are using the model_update_lock to check if there is still on-going requests.
while True:
@@ -1341,7 +1345,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
async def continue_generation(self, obj: ContinueGenerationReqInput):
async with self.is_pause_cond:
self.is_pause = False
- await self.send_to_scheduler.send_pyobj(obj)
+ self.send_to_scheduler.send_pyobj(obj)
self.is_pause_cond.notify_all()
async def update_weights_from_disk(
@@ -1510,6 +1514,40 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
self._add_metric_if_present(
recv_obj, "prefill_finished_ts", meta_info, i
)
+ # PD disaggregation timing
+ self._add_metric_if_present(
+ recv_obj, "pd_prefill_bootstrap_queue_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_prefill_forward_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_prefill_transfer_queue_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_decode_prealloc_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_decode_transfer_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_decode_forward_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_bootstrap_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_alloc_waiting_duration", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_transfer_speed_gb_s", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_transfer_total_mb", meta_info, i
+ )
+ self._add_metric_if_present(
+ recv_obj, "pd_prefill_retry_count", meta_info, i
+ )
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
@@ -1955,19 +1993,17 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
if custom_labels
else self.metrics_collector.labels
)
- if (
- state.first_token_time == 0.0
- and self.disaggregation_mode != DisaggregationMode.PREFILL
- ):
+ if state.first_token_time == 0.0:
state.first_token_time = state.last_time = time.time()
state.first_token_time_perf = time.perf_counter()
state.last_completion_tokens = completion_tokens
- self.metrics_collector.observe_time_to_first_token(
- labels, state.first_token_time - state.created_time
- )
+ if self.disaggregation_mode != DisaggregationMode.PREFILL:
+ self.metrics_collector.observe_time_to_first_token(
+ labels, state.first_token_time - state.created_time
+ )
else:
num_new_tokens = completion_tokens - state.last_completion_tokens
- if num_new_tokens:
+ if num_new_tokens > 0:
new_time = time.time()
interval = new_time - state.last_time
self.metrics_collector.observe_inter_token_latency(
@@ -1976,7 +2012,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
num_new_tokens,
)
state.last_time = new_time
- state.last_completion_tokens = completion_tokens
+ state.last_completion_tokens = completion_tokens
if state.finished:
retraction_count = (
@@ -29,6 +29,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterFromTensorsReqInput,
LoadLoRAAdapterReqInput,
+ PostProcessWeightsReqInput,
SendWeightsToRemoteInstanceReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
@@ -168,6 +169,11 @@ class BaseTpWorker(ABC):
success, message = self.model_runner.update_weights_from_ipc(recv_req)
return success, message
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion)."""
+ success, message = self.model_runner.post_process_weights(recv_req)
+ return success, message
+
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
@@ -347,6 +347,84 @@ def alloc_decode_kernel(
tl.store(out_indices + pid, page * page_size)
+def alloc_extend_torch_fallback(
+ prefix_lens_cpu: torch.Tensor,
+ seq_lens_cpu: torch.Tensor,
+ last_loc: torch.Tensor,
+ free_pages: torch.Tensor,
+ out_indices: torch.Tensor,
+ page_size: int,
+ debug_mode: bool = False,
+):
+ extend_lens_cpu = (seq_lens_cpu - prefix_lens_cpu).to(torch.int64)
+ if extend_lens_cpu.numel() == 0:
+ return
+
+ output_start_locs_cpu = torch.cumsum(extend_lens_cpu, dim=0) - extend_lens_cpu
+ num_pages_after = (seq_lens_cpu + page_size - 1) // page_size
+ num_pages_before = (prefix_lens_cpu + page_size - 1) // page_size
+ num_new_pages_cpu = num_pages_after - num_pages_before
+ page_start_locs_cpu = torch.cumsum(num_new_pages_cpu, dim=0) - num_new_pages_cpu
+
+ total_new_pages = int(num_new_pages_cpu.sum().item())
+ if total_new_pages > free_pages.numel():
+ return
+
+ if debug_mode:
+ assert int(extend_lens_cpu.sum().item()) == out_indices.numel()
+
+ prefix_lens_list = prefix_lens_cpu.tolist()
+ seq_lens_list = seq_lens_cpu.tolist()
+ extend_lens_list = extend_lens_cpu.tolist()
+ out_start_list = output_start_locs_cpu.tolist()
+ page_start_list = page_start_locs_cpu.tolist()
+ num_new_pages_list = num_new_pages_cpu.tolist()
+
+ device = out_indices.device
+ dtype = out_indices.dtype
+ offsets_page = torch.arange(page_size, device=device, dtype=dtype)
+
+ for i, extend_len in enumerate(extend_lens_list):
+ if extend_len == 0:
+ continue
+
+ pre_len = prefix_lens_list[i]
+ seq_len = seq_lens_list[i]
+ out_start = out_start_list[i]
+ page_start = page_start_list[i]
+ num_new_pages = num_new_pages_list[i]
+
+ pre_mod = pre_len % page_size
+ part1 = min(extend_len, page_size - pre_mod) if pre_mod != 0 else 0
+ if part1:
+ start_val = last_loc[i] + 1
+ out_indices[out_start : out_start + part1] = start_val + torch.arange(
+ part1, device=device, dtype=dtype
+ )
+ if part1 == extend_len:
+ continue
+
+ ceil_pre_pages = (pre_len + page_size - 1) // page_size
+ full_pages_after = seq_len // page_size
+ num_full_pages = full_pages_after - ceil_pre_pages
+ if num_full_pages < 0:
+ num_full_pages = 0
+ part2 = num_full_pages * page_size
+ if part2:
+ pages = free_pages[page_start : page_start + num_full_pages]
+ full_indices = (pages[:, None] * page_size + offsets_page).reshape(-1)
+ out_indices[out_start + part1 : out_start + part1 + part2] = full_indices
+ if part1 + part2 == extend_len:
+ continue
+
+ part3 = extend_len - part1 - part2
+ if part3:
+ last_page = free_pages[page_start + num_new_pages - 1]
+ out_indices[out_start + part1 + part2 : out_start + extend_len] = (
+ last_page * page_size + torch.arange(part3, device=device, dtype=dtype)
+ )
+
+
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""
An allocator managing the indices to kv cache data.
@@ -411,7 +489,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.seen_max_num_extend_tokens_next_power_of_2 = max(
self.seen_max_num_extend_tokens_next_power_of_2,
- min(tl.core.TRITON_MAX_TENSOR_NUMEL, next_power_of_2(extend_num_tokens)),
+ min(65536, next_power_of_2(extend_num_tokens)),
)
bs = len(prefix_lens)
@@ -424,7 +502,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
- if extend_num_tokens < tl.core.TRITON_MAX_TENSOR_NUMEL:
+ if extend_num_tokens < 65536:
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
@@ -76,6 +76,7 @@ class HiRadixCache(RadixCache):
allocator_type=server_args.hicache_storage_backend,
)
elif isinstance(self.kv_cache, NSATokenToKVPool):
+ # Check NSA before MLA since NSATokenToKVPool is a subclass of MLATokenToKVPool
self.token_to_kv_pool_host = NSATokenToKVPoolHost(
self.kv_cache,
server_args.hicache_ratio,
@@ -94,7 +95,7 @@ class HiRadixCache(RadixCache):
allocator_type=server_args.hicache_storage_backend,
)
else:
- raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
+ raise ValueError(f"HiRadixCache only supports MHA and MLA and NSA yet")
self.tp_group = params.tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
@@ -750,9 +751,8 @@ class HiRadixCache(RadixCache):
self._update_leaf_status(node)
self._update_host_leaf_status(node)
if node.parent is None:
- assert (
- node is self.root_node
- ), f"This request holds the node from another tree"
+ # Node belongs to a stale (flushed) tree — stop traversal gracefully.
+ break
node = node.parent
return delta
@@ -827,6 +827,7 @@ class HiRadixCache(RadixCache):
self._update_host_leaf_status(node)
# update leaf status for the parent because the node is evicted
self._update_leaf_status(node.parent)
+ self._update_host_leaf_status(node.parent)
return num_evicted
def _evict_regular(self, node: TreeNode):
@@ -1330,6 +1331,7 @@ class HiRadixCache(RadixCache):
self._update_host_leaf_status(node)
# update parent status as a new leaf is added into device
self._update_leaf_status(node.parent)
+ self._update_host_leaf_status(node.parent)
else:
self._inc_hit_count(node, chunked)
total_prefix_length += prefix_len
@@ -1345,6 +1347,7 @@ class HiRadixCache(RadixCache):
self._update_host_leaf_status(new_node)
# update parent status as a new leaf is added into device
self._update_leaf_status(new_node.parent)
+ self._update_host_leaf_status(new_node.parent)
else:
self._inc_hit_count(new_node, chunked)
total_prefix_length += prefix_len
@@ -1777,9 +1777,12 @@ class NSATokenToKVPool(MLATokenToKVPool):
else:
assert self.page_size == 64
with (
- torch.cuda.use_mem_pool(self.custom_mem_pool)
- if self.custom_mem_pool
- else nullcontext()
+ (
+ torch.cuda.use_mem_pool(self.custom_mem_pool)
+ if self.custom_mem_pool
+ else nullcontext()
+ ),
+ self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE),
):
self.index_k_with_scale_buffer = [
torch.zeros(
@@ -1801,6 +1804,11 @@ class NSATokenToKVPool(MLATokenToKVPool):
)
for _ in range(layer_num)
]
+ self.index_k_with_scale_buffer_ptrs = torch.tensor(
+ [x.data_ptr() for x in self.index_k_with_scale_buffer],
+ dtype=torch.uint64,
+ device=self.device,
+ )
self._finalize_allocation_log(size)
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
@@ -1876,6 +1884,50 @@ class NSATokenToKVPool(MLATokenToKVPool):
]
return data_ptrs, data_lens, item_lens
+ def get_cpu_copy(self, indices):
+ # First, save the kv_buffer (inherited from MLATokenToKVPool)
+ kv_cache_cpu = super().get_cpu_copy(indices)
+
+ # Additionally, save the index_k_with_scale_buffer (page-indexed)
+ page_indices = indices[:: self.page_size] // self.page_size
+ torch.cuda.synchronize()
+ index_k_cpu = []
+ chunk_size = self.cpu_offloading_chunk_size
+ # Convert chunk_size from token-level to page-level
+ page_chunk_size = max(1, chunk_size // self.page_size)
+ for layer_id in range(self.layer_num):
+ index_k_cpu.append([])
+ for i in range(0, len(page_indices), page_chunk_size):
+ chunk_page_indices = page_indices[i : i + page_chunk_size]
+ idx_cpu = self.index_k_with_scale_buffer[layer_id][
+ chunk_page_indices
+ ].to("cpu", non_blocking=True)
+ index_k_cpu[-1].append(idx_cpu)
+ torch.cuda.synchronize()
+
+ return {"kv": kv_cache_cpu, "index_k": index_k_cpu}
+
+ def load_cpu_copy(self, kv_cache_cpu_dict, indices):
+ # Restore the kv_buffer (inherited from MLATokenToKVPool)
+ super().load_cpu_copy(kv_cache_cpu_dict["kv"], indices)
+
+ # Restore the index_k_with_scale_buffer (page-indexed)
+ page_indices = indices[:: self.page_size] // self.page_size
+ index_k_cpu = kv_cache_cpu_dict["index_k"]
+ torch.cuda.synchronize()
+ chunk_size = self.cpu_offloading_chunk_size
+ page_chunk_size = max(1, chunk_size // self.page_size)
+ for layer_id in range(self.layer_num):
+ for i in range(0, len(page_indices), page_chunk_size):
+ chunk_page_indices = page_indices[i : i + page_chunk_size]
+ idx_cpu = index_k_cpu[layer_id][i // page_chunk_size]
+ assert idx_cpu.shape[0] == len(chunk_page_indices)
+ idx_chunk = idx_cpu.to(
+ self.index_k_with_scale_buffer[0].device, non_blocking=True
+ )
+ self.index_k_with_scale_buffer[layer_id][chunk_page_indices] = idx_chunk
+ torch.cuda.synchronize()
+
def get_kv_size_bytes(self):
kv_size_bytes = super().get_kv_size_bytes()
for index_k_cache in self.index_k_with_scale_buffer:
@@ -495,7 +495,17 @@ class RadixCache(BasePrefixCache):
if self.disable:
return
- token_ids = req.fill_ids
+ # Limit to kv_committed_len to avoid including tokens (e.g., the just-generated
+ # token in disagg prefill) that don't have computed KV yet. If fill_ids is longer
+ # than kv_committed_len, the extra tokens would produce stale values (0 from
+ # req_to_token_pool initialization), leading to spurious tree nodes and a
+ # memory leak when page-aligned token counts happen to cross a page boundary.
+ kv_committed_len = req.kv_committed_len
+ token_ids = (
+ req.fill_ids[:kv_committed_len]
+ if kv_committed_len < len(req.fill_ids)
+ else req.fill_ids
+ )
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
@@ -619,9 +629,8 @@ class RadixCache(BasePrefixCache):
node.lock_ref -= 1
self._update_leaf_status(node)
if node.parent is None:
- assert (
- node is self.root_node
- ), f"This request holds the node from another tree"
+ # Node belongs to a stale (flushed) tree — stop traversal gracefully.
+ break
node = node.parent
return delta
@@ -20,7 +20,10 @@ import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
-from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.disaggregation.utils import (
+ DisaggregationMode,
+ is_slime_profiling_enabled,
+)
from sglang.srt.environ import envs
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -77,6 +80,17 @@ class TimeStats:
# Number of prefill retries for this request
prefill_retry_count: int = 0
+ # Prefill-side durations forwarded via metadata transfer from P to D instance.
+ # Set on the decode instance after KV cache transfer completes.
+ fwd_prefill_bootstrap_queue_duration: Optional[float] = None
+ fwd_prefill_forward_duration: Optional[float] = None
+ fwd_prefill_transfer_queue_duration: Optional[float] = None
+ fwd_bootstrap_duration: Optional[float] = None
+ fwd_alloc_waiting_duration: Optional[float] = None
+ fwd_transfer_speed_gb_s: Optional[float] = None
+ fwd_transfer_total_mb: Optional[float] = None
+ fwd_prefill_retry_count: Optional[int] = None
+
# Timestamp when prefill phase finishes, obtained from `time.time()`.
# Note that this differs from the other `_time` fields tracked by the
# `TimeStats` class, which are obtained from `time.perf_counter()`.
@@ -102,6 +116,148 @@ class TimeStats:
return self.prefill_finished_ts
return None
+ # --- PD disaggregation timing getters ---
+
+ def get_pd_prefill_bootstrap_queue_duration(self) -> Optional[float]:
+ """P instance: time spent in bootstrap queue before entering the wait queue."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_prefill_bootstrap_queue_duration is not None:
+ return self.fwd_prefill_bootstrap_queue_duration
+ if (
+ self.disagg_mode == DisaggregationMode.PREFILL
+ and self.prefill_bootstrap_queue_entry_time > 0.0
+ and self.wait_queue_entry_time > 0.0
+ ):
+ return self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
+ return None
+
+ def get_pd_prefill_forward_duration(self) -> Optional[float]:
+ """P instance: time for the actual prefill forward computation."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_prefill_forward_duration is not None:
+ return self.fwd_prefill_forward_duration
+ if (
+ self.disagg_mode == DisaggregationMode.PREFILL
+ and self.forward_entry_time > 0.0
+ and self.completion_time > 0.0
+ ):
+ return self.completion_time - self.forward_entry_time
+ return None
+
+ def get_pd_prefill_transfer_queue_duration(self) -> Optional[float]:
+ """P instance: time spent in the transfer queue (KV cache send)."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_prefill_transfer_queue_duration is not None:
+ return self.fwd_prefill_transfer_queue_duration
+ if (
+ self.disagg_mode == DisaggregationMode.PREFILL
+ and self.prefill_transfer_queue_entry_time > 0.0
+ and self.completion_time > 0.0
+ ):
+ return self.completion_time - self.prefill_transfer_queue_entry_time
+ return None
+
+ def get_pd_decode_prealloc_duration(self) -> Optional[float]:
+ """D instance: time spent in the pre-alloc queue (waiting for KV cache slot allocation)."""
+ if not is_slime_profiling_enabled():
+ return None
+ if (
+ self.disagg_mode == DisaggregationMode.DECODE
+ and self.decode_prealloc_queue_entry_time > 0.0
+ and self.decode_transfer_queue_entry_time > 0.0
+ ):
+ return (
+ self.decode_transfer_queue_entry_time
+ - self.decode_prealloc_queue_entry_time
+ )
+ return None
+
+ def get_pd_decode_transfer_duration(self) -> Optional[float]:
+ """D instance: time spent waiting for KV cache transfer to complete."""
+ if not is_slime_profiling_enabled():
+ return None
+ if (
+ self.disagg_mode == DisaggregationMode.DECODE
+ and self.decode_transfer_queue_entry_time > 0.0
+ and self.wait_queue_entry_time > 0.0
+ ):
+ return self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
+ return None
+
+ def get_pd_decode_forward_duration(self) -> Optional[float]:
+ """D instance: time for the actual decode forward computation."""
+ if not is_slime_profiling_enabled():
+ return None
+ if (
+ self.disagg_mode == DisaggregationMode.DECODE
+ and self.forward_entry_time > 0.0
+ and self.completion_time > 0.0
+ ):
+ return self.completion_time - self.forward_entry_time
+ return None
+
+ def get_pd_bootstrap_duration(self) -> Optional[float]:
+ """Bootstrap handshake duration (both P and D instances)."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_bootstrap_duration is not None:
+ return self.fwd_bootstrap_duration
+ if (
+ self.disagg_mode != DisaggregationMode.NULL
+ and self.bootstrap_duration > 0.0
+ ):
+ return self.bootstrap_duration
+ return None
+
+ def get_pd_alloc_waiting_duration(self) -> Optional[float]:
+ """KV cache allocation waiting duration (both P and D instances)."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_alloc_waiting_duration is not None:
+ return self.fwd_alloc_waiting_duration
+ if (
+ self.disagg_mode != DisaggregationMode.NULL
+ and self.alloc_waiting_duration > 0.0
+ ):
+ return self.alloc_waiting_duration
+ return None
+
+ def get_pd_transfer_speed_gb_s(self) -> Optional[float]:
+ """KV cache transfer speed in GB/s."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_transfer_speed_gb_s is not None:
+ return self.fwd_transfer_speed_gb_s
+ if (
+ self.disagg_mode != DisaggregationMode.NULL
+ and self.transfer_speed_gb_s > 0.0
+ ):
+ return self.transfer_speed_gb_s
+ return None
+
+ def get_pd_transfer_total_mb(self) -> Optional[float]:
+ """Total KV cache transferred in MB."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_transfer_total_mb is not None:
+ return self.fwd_transfer_total_mb
+ if self.disagg_mode != DisaggregationMode.NULL and self.transfer_total_mb > 0.0:
+ return self.transfer_total_mb
+ return None
+
+ def get_pd_prefill_retry_count(self) -> Optional[int]:
+ """Number of prefill retries for this request."""
+ if not is_slime_profiling_enabled():
+ return None
+ if self.fwd_prefill_retry_count is not None:
+ return self.fwd_prefill_retry_count
+ if self.disagg_mode == DisaggregationMode.PREFILL:
+ return self.prefill_retry_count
+ return None
+
def convert_to_duration(self) -> str:
if self.disagg_mode == DisaggregationMode.NULL:
queue_duration = self.forward_entry_time - self.wait_queue_entry_time
@@ -909,6 +909,28 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
tokens_padded = (tokens + rank_size - 1) // rank_size * rank_size
self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size)
+ def prepare_cp_padding(self, model_runner: ModelRunner):
+ """Pad input_ids and extend_num_tokens to CP size multiples.
+
+ In the PP disagg prefill + CP path, MLP sync is skipped so
+ prepare_mlp_sync_batch never runs. This method performs the
+ subset of padding that CP collective communication requires:
+ input_ids (and related tensors) must be divisible by cp_size.
+ """
+ attn_cp_size = get_attention_cp_size()
+ if attn_cp_size <= 1:
+ return
+ if not self.forward_mode.is_extend():
+ return
+
+ tokens = self.input_ids.shape[0]
+ tokens_padded = ceil_align(tokens, attn_cp_size)
+ if tokens_padded == tokens:
+ return
+
+ self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size)
+ self.extend_num_tokens = tokens_padded
+
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
@@ -395,7 +395,12 @@ class ModelRunner(ModelRunnerKVCacheMixin):
self.forward_stream = torch.get_device_module(self.device).Stream()
# CPU offload
- set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
+ # For draft worker (e.g., MTP), do not set offloader to avoid overriding
+ # the main model's offloader. Draft worker uses NoopOffloader instead.
+ if not is_draft_worker:
+ set_offloader(
+ create_offloader_from_server_args(server_args, dp_rank=dp_rank)
+ )
self._weight_checker = WeightChecker(model_runner=self)
@@ -600,7 +605,8 @@ class ModelRunner(ModelRunnerKVCacheMixin):
)
# Init routed experts capturer
- self.init_routed_experts_capturer()
+ if not self.is_draft_worker:
+ self.init_routed_experts_capturer()
if self.device == "cuda" or self.device == "musa":
self.init_cublas()
@@ -2429,11 +2435,19 @@ class ModelRunner(ModelRunnerKVCacheMixin):
output.expert_distribution_metrics = recorder_outputs.get("metrics")
# Copy cached routing experts' buffers back to CPU cache
- get_global_experts_capturer().on_forward_end(
- forward_batch=forward_batch,
- can_run_graph=output.can_run_graph,
- cuda_graph_batch=getattr(self.graph_runner, "bs", None),
- )
+ if not self.is_draft_worker:
+ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass the
+ # actual number of tokens per DP rank in the CUDA graph, not the batch size.
+ cuda_graph_num_tokens = None
+ if getattr(self.graph_runner, "bs", None):
+ cuda_graph_num_tokens = (
+ self.graph_runner.bs * self.graph_runner.num_tokens_per_bs
+ )
+ get_global_experts_capturer().on_forward_end(
+ forward_batch=forward_batch,
+ can_run_graph=output.can_run_graph,
+ cuda_graph_batch=cuda_graph_num_tokens,
+ )
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end()
@@ -2472,6 +2486,9 @@ class ModelRunner(ModelRunnerKVCacheMixin):
forward_batch.prepare_mlp_sync_batch(self)
else:
forward_batch.prepare_attn_tp_scatter_input(self)
+ # In PP disagg prefill + CP, MLP sync is skipped so CP padding
+ # must be done separately to keep input_ids divisible by cp_size.
+ forward_batch.prepare_cp_padding(self)
# Normalize num_token_non_padded to be local to this attention TP rank if needed.
if (
@@ -2664,6 +2681,42 @@ class ModelRunner(ModelRunnerKVCacheMixin):
device=self.device,
)
+ def post_process_weights(self, recv_req):
+ """
+ Execute post-processing logic for model weights, such as Marlin quantization format conversion.
+ """
+ from sglang.srt.model_loader.loader import device_loading_context
+
+ target_device = torch.device("cuda", torch.cuda.current_device())
+
+ if recv_req.restore_weights_before_load:
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+
+ # Check if the module supports restoring weights
+ if quant_method is not None and hasattr(
+ quant_method, "restore_weights_before_loading"
+ ):
+
+ with device_loading_context(module, target_device):
+ quant_method.restore_weights_before_loading(module)
+
+ if recv_req.post_process_quantization:
+ # Iterate through all modules to apply specific post-loading processing
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+
+ # Check if the module supports quantization post-processing
+ if quant_method is not None and hasattr(
+ quant_method, "process_weights_after_loading"
+ ):
+
+ # Apply the post-processing (e.g., repacking weights for Marlin kernel)
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+
+ return True, "Success"
+
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
@@ -1,4 +1,5 @@
from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
+from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.models.deepseek_common.attention_forward_methods.forward_methods import (
AttnForwardMethod,
@@ -150,6 +151,8 @@ def handle_attention_nsa(attn, forward_batch):
backend = forward_batch.attn_backend
if isinstance(backend, TboAttnBackend): # if enable tbo, get primary backend
backend = backend.primary
+ if isinstance(backend, HybridAttnBackend):
+ backend = backend._select_backend(forward_batch.forward_mode)
if hasattr(backend, "use_mha") and backend.use_mha:
return AttnForwardMethod.MHA_ONE_SHOT
return AttnForwardMethod.MLA
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.nsa.utils import (
can_cp_split,
cp_all_gather_rerange_output,
cp_split_and_rebuild_data,
+ cp_split_and_rebuild_position,
is_nsa_enable_prefill_cp,
nsa_use_prefill_cp,
prepare_input_dp_with_cp_dsa,
@@ -160,15 +161,17 @@ class DeepseekModelNextN(nn.Module):
if nsa_use_prefill_cp(forward_batch, self.nsa_enable_prefill_cp):
hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states)
+ positions = cp_split_and_rebuild_position(forward_batch, positions)
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
- hidden_states, residual = self.decoder(
+ hidden_states, residual, *rest = self.decoder(
positions,
hidden_states,
forward_batch,
residual,
zero_allocator,
)
+ topk_indices = rest[0] if rest else None
if not forward_batch.forward_mode.is_idle():
if residual is not None:
@@ -1085,6 +1085,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
skip_rope: bool = False,
+ is_nextn: bool = False,
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -1154,6 +1155,8 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
+ self.skip_topk = False
+ self.next_skip_topk = False
if self.use_nsa:
is_neox_style = not getattr(config, "indexer_rope_interleave", False)
self.indexer = Indexer(
@@ -1174,6 +1177,31 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
layer_id=layer_id,
alt_stream=alt_stream,
)
+ if not is_nextn:
+ self.index_topk_freq = getattr(config, "index_topk_freq", 1)
+ self.index_topk_pattern = getattr(config, "index_topk_pattern", None)
+ self.index_skip_topk_offset = getattr(
+ config, "index_skip_topk_offset", 2
+ )
+ if self.index_topk_pattern is None:
+ self.skip_topk = (
+ max(layer_id - self.index_skip_topk_offset + 1, 0)
+ % self.index_topk_freq
+ != 0
+ )
+ self.next_skip_topk = (
+ max(layer_id - self.index_skip_topk_offset + 2, 0)
+ % self.index_topk_freq
+ != 0
+ )
+ else:
+ self.skip_topk = self.index_topk_pattern[layer_id] == "S"
+ if layer_id < len(self.index_topk_pattern) - 1:
+ self.next_skip_topk = (
+ self.index_topk_pattern[layer_id + 1] == "S"
+ )
+ else:
+ self.next_skip_topk = False
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
@@ -1362,6 +1390,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
s = self.forward_prepare(
positions=positions,
@@ -1369,6 +1398,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
forward_batch=forward_batch,
zero_allocator=zero_allocator,
llama_4_scaling=llama_4_scaling,
+ prev_topk_indices=prev_topk_indices,
)
return self.forward_core(s)
@@ -1379,6 +1409,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
if self.attn_mha.kv_b_proj is None:
self.attn_mha.kv_b_proj = self.kv_b_proj
@@ -1418,7 +1449,12 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
)
elif attn_forward_method == AttnForwardMethod.MLA:
inner_state = self.forward_absorb_prepare(
- positions, hidden_states, forward_batch, zero_allocator, llama_4_scaling
+ positions,
+ hidden_states,
+ forward_batch,
+ zero_allocator,
+ llama_4_scaling,
+ prev_topk_indices,
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
@@ -1529,6 +1565,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
@@ -1620,18 +1657,7 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
q = self.q_b_proj(q)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
- topk_indices = self.indexer(
- x=hidden_states,
- q_lora=q_lora,
- positions=positions,
- forward_batch=forward_batch,
- layer_id=self.layer_id,
- )
- current_stream.wait_stream(self.alt_stream)
- else:
- k_nope = k_nope.unsqueeze(1)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- if q_lora is not None:
+ if not self.skip_topk:
topk_indices = self.indexer(
x=hidden_states,
q_lora=q_lora,
@@ -1639,6 +1665,23 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
forward_batch=forward_batch,
layer_id=self.layer_id,
)
+ else:
+ topk_indices = prev_topk_indices
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ k_nope = k_nope.unsqueeze(1)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ if q_lora is not None:
+ if not self.skip_topk:
+ topk_indices = self.indexer(
+ x=hidden_states,
+ q_lora=q_lora,
+ positions=positions,
+ forward_batch=forward_batch,
+ layer_id=self.layer_id,
+ )
+ else:
+ topk_indices = prev_topk_indices
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
@@ -1929,8 +1972,10 @@ class DeepseekV2AttentionMLA(nn.Module, DeepseekMHAForwardMixin):
).transpose(0, 1),
)
output, _ = self.o_proj(attn_bmm_output)
-
- return output
+ if not self.next_skip_topk:
+ return output, None
+ else:
+ return output, topk_indices
def forward_absorb_fused_mla_rope_prepare(
self,
@@ -2275,6 +2320,7 @@ class DeepseekV2DecoderLayer(nn.Module):
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
+ is_nextn=is_nextn,
)
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
@@ -2357,6 +2403,7 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
llama_4_scaling: Optional[torch.Tensor] = None,
+ prev_topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
quant_format = (
"mxfp4"
@@ -2398,7 +2445,12 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
zero_allocator=zero_allocator,
llama_4_scaling=llama_4_scaling,
+ prev_topk_indices=prev_topk_indices,
)
+ if isinstance(hidden_states, tuple):
+ hidden_states, topk_indices = hidden_states
+ else:
+ topk_indices = None
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
@@ -2434,7 +2486,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
- return hidden_states, residual
+ return hidden_states, residual, topk_indices
def op_comm_prepare_attn(
self,
@@ -2710,6 +2762,7 @@ class DeepseekV2Model(nn.Module):
elif self.first_k_dense_replace < normal_start_layer:
normal_end_layer = normal_start_layer = 0
aux_hidden_states = []
+ topk_indices = None
for i in range(normal_start_layer, normal_end_layer):
# NOTE: torch dynamo does not support graph break in context manager
ctx = (
@@ -2727,7 +2780,7 @@ class DeepseekV2Model(nn.Module):
else:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
- hidden_states, residual = layer(
+ hidden_states, residual, *rest = layer(
positions,
hidden_states,
forward_batch,
@@ -2735,7 +2788,9 @@ class DeepseekV2Model(nn.Module):
zero_allocator,
gemm_output_zero_allocator,
llama_4_scaling,
+ prev_topk_indices=topk_indices,
)
+ topk_indices = rest[0] if rest else None
if normal_end_layer != self.end_layer:
hidden_states, residual = model_forward_maybe_tbo(
@@ -678,8 +678,13 @@ class Glm4MoeDecoderLayer(nn.Module):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.config = config
- rope_theta = getattr(config, "rope_theta", 10000)
- rope_scaling = getattr(config, "rope_scaling", None)
+ # rope_theta may be stored in rope_parameters dict (e.g. GLM-4.6V)
+ _rope_params = getattr(config, "rope_parameters", None)
+ if isinstance(_rope_params, dict) and "rope_theta" in _rope_params:
+ rope_theta = _rope_params["rope_theta"]
+ else:
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None) or _rope_params
partial_rotary_factor = getattr(
getattr(config, "rope_parameters", None), "partial_rotary_factor", None
) or getattr(config, "partial_rotary_factor", 0.5)
@@ -773,6 +778,7 @@ class Glm4MoeDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
+ **kwargs,
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
@@ -103,7 +103,7 @@ class Glm4MoeModelNextN(nn.Module):
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
- hidden_states, residual = self.decoder(
+ hidden_states, residual, *rest = self.decoder(
positions, hidden_states, forward_batch, residual
)
@@ -52,11 +52,31 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.num_fused_shared_experts = 0
self.determine_num_fused_shared_experts()
- self.model = Glm4MoeModel(
- config,
- quant_config,
- prefix=add_prefix("language_model", prefix),
- )
+ if not self.config.encoder_only:
+ self.model = Glm4MoeModel(
+ config,
+ quant_config,
+ prefix=add_prefix("language_model", prefix),
+ )
+
+ if self.pp_group.is_last_rank:
+ if self.pp_group.world_size == 1 and self.config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
+ use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
+ )
+ else:
+ # ranks other than the last rank will have a placeholder layer
+ self.lm_head = PPMissingLayer()
+ else:
+ # encoder_only mode: no language model, so no lm_head needed
+ self.lm_head = None
+
self.visual = Glm4vVisionModel(
config.vision_config,
quant_config=quant_config,
@@ -64,24 +84,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
use_data_parallel=self.use_data_parallel,
)
- if self.pp_group.is_last_rank:
- if self.pp_group.world_size == 1 and self.config.tie_word_embeddings:
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=add_prefix("lm_head", prefix),
- use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
- )
- else:
- # ranks other than the last rank will have a placeholder layer
- self.lm_head = PPMissingLayer()
-
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
- self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+ _rope_cfg = (
+ getattr(self.config, "rope_scaling", None)
+ or getattr(self.config, "rope_parameters", None)
+ or {}
+ )
+ self.is_mrope_enabled = "mrope_section" in _rope_cfg
# For EAGLE3 support
self.capture_aux_hidden_states = False
@@ -219,6 +229,11 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+ # Skip loading visual/language model weights
+ if (
+ self.config.encoder_only or self.config.language_only
+ ) and name not in params_dict:
+ continue
if name not in params_dict:
continue
@@ -234,6 +249,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
+ if "visual" in name or self.config.encoder_only:
+ continue
# Mark as expert weight regardless of whether we can process it
is_expert_weight = True
@@ -265,6 +282,11 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+ # Skip loading mm/language parameters
+ if (
+ self.config.encoder_only or self.config.language_only
+ ) and name not in params_dict:
+ continue
if name not in params_dict:
continue
@@ -17,6 +17,7 @@
import logging
import math
+import re
from collections.abc import Iterable
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1065,6 +1066,12 @@ class GptOssForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
+ # Try per-expert format: experts.{id}.{gate_proj|up_proj|down_proj}.{weight|bias}
+ per_expert_match = _PER_EXPERT_RE.match(name)
+ if per_expert_match:
+ _load_per_expert_param(per_expert_match, loaded_weight, params_dict)
+ continue
+
for mapping in expert_params_mapping:
param_name, weight_name, shard_id = mapping
if weight_name not in name:
@@ -1143,6 +1150,88 @@ class GptOssForCausalLM(nn.Module):
return get_attention_sliding_window_size(self.config)
+# Regex for per-expert weight names: model.layers.X.mlp.experts.E.{proj}.{weight|bias}
+# Groups: (name prefix up to "experts.", expert id, projection name, "weight"/"bias").
+# The trailing `$` matters because the pattern is used with `re.match`, which only
+# anchors at the start: without it, suffixed names such as
+# `...gate_proj.weight_scale` would be accepted and treated as plain weights.
+_PER_EXPERT_RE = re.compile(
+    r"(.+\.mlp\.experts\.)(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$"
+)
+
+
+def _load_per_expert_param(match, loaded_weight, params_dict):
+    """Copy one per-expert weight or bias tensor into its fused MoE parameter.
+
+    Per-expert checkpoint names (e.g., experts.0.gate_proj.weight) are mapped
+    onto fused parameters: ``gate_proj``/``up_proj`` go to ``w13_weight`` /
+    ``w13_weight_bias`` and ``down_proj`` goes to ``w2_weight`` /
+    ``w2_weight_bias``. The copy is done in place on ``param.data[eid]``;
+    nothing is returned.
+
+    Args:
+        match: Regex match whose four groups are (name prefix, expert id,
+            projection name, ``"weight"`` or ``"bias"``).
+        loaded_weight: Tensor read from the checkpoint for that expert.
+        params_dict: Mapping from parameter name to parameter. If the fused
+            target name is not present, the tensor is silently skipped.
+    """
+    prefix, eid_str, proj, ptype = match.groups()
+    eid = int(eid_str)
+
+    # Determine target fused parameter name
+    if proj in ("gate_proj", "up_proj"):
+        key = prefix + ("w13_weight" if ptype == "weight" else "w13_weight_bias")
+    else:  # down_proj
+        key = prefix + ("w2_weight" if ptype == "weight" else "w2_weight_bias")
+
+    # Unknown target: skip without raising.
+    if key not in params_dict:
+        return
+
+    param = params_dict[key]
+    # NOTE(review): eid indexes the fused parameter's first dim directly, which
+    # assumes the parameter holds every expert in checkpoint order - confirm for
+    # expert-parallel setups.
+    expert_slice = param.data[eid]  # slice for this expert
+
+    # Detect triton transposed layout from shape:
+    # w13: transposed=(E, hidden, 2*inter), non-transposed=(E, 2*inter, hidden)
+    # For w13, the larger dim is 2*intermediate; if it's dim -1, layout is transposed.
+    # An explicit `is_transposed` attribute on the parameter takes precedence; the
+    # shape heuristic is only consulted when that attribute is missing or falsy.
+    is_transposed = getattr(param, "is_transposed", False)
+    if not is_transposed and ptype == "weight" and "w13" in key:
+        # Infer from shape: transposed has shape[-1] > shape[-2]
+        is_transposed = param.data.shape[-1] > param.data.shape[-2]
+
+    # NOTE(review): every copy below transfers loaded_weight whole; no
+    # tensor-parallel slicing is visible here, so shapes must already match the
+    # destination - confirm against the caller.
+    if ptype == "weight":
+        if proj in ("gate_proj", "up_proj"):
+            # w13_weight: gate in first half, up in second half
+            if is_transposed:
+                # Triton layout: (E, hidden, 2*intermediate)
+                half = expert_slice.shape[1] // 2
+                dst = (
+                    expert_slice[:, :half]
+                    if proj == "gate_proj"
+                    else expert_slice[:, half:]
+                )
+            else:
+                # Standard layout: (E, 2*intermediate, hidden)
+                half = expert_slice.shape[0] // 2
+                dst = (
+                    expert_slice[:half] if proj == "gate_proj" else expert_slice[half:]
+                )
+            # loaded_weight shape: (intermediate, hidden)
+            if is_transposed:
+                dst.copy_(loaded_weight.t())
+            else:
+                dst.copy_(loaded_weight)
+        else:
+            # w2_weight: loaded_weight shape (hidden, intermediate)
+            # Detect transposition for w2 as well
+            # NOTE(review): when the last two dims are equal the shape test is
+            # False, so a square w2 is copied untransposed unless `is_transposed`
+            # is set on the parameter - confirm this is intended.
+            w2_transposed = is_transposed
+            if not w2_transposed:
+                w2_transposed = param.data.shape[-1] > param.data.shape[-2]
+            if w2_transposed:
+                expert_slice.copy_(loaded_weight.t())
+            else:
+                expert_slice.copy_(loaded_weight)
+    else:
+        # Bias handling
+        if proj in ("gate_proj", "up_proj"):
+            # w13_weight_bias: (E, 2*intermediate)
+            half = expert_slice.shape[0] // 2
+            dst = expert_slice[:half] if proj == "gate_proj" else expert_slice[half:]
+            dst.copy_(loaded_weight)
+        else:
+            # w2_weight_bias: (E, hidden) - only rank 0 loads, others zero
+            if get_moe_tensor_parallel_rank() == 0:
+                expert_slice.copy_(loaded_weight)
+            else:
+                expert_slice.zero_()
+
+
def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
weights_out_dict = dict(weights_in)
@@ -666,25 +666,30 @@ class KimiK25ForConditionalGeneration(nn.Module):
self.config = config
self.quant_config = quant_config
self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
- # Create vision tower
- self.vision_tower = MoonViT3dPretrainedModel(
- config.vision_config, use_data_parallel=self.use_data_parallel
- )
- # Create mm projector
- self.mm_projector = K2VLMultiModalProjector(config.vision_config)
- self.language_model = DeepseekV3ForCausalLM(config.text_config, quant_config)
+ # EPD: conditionally create components based on encoder_only / language_only
+ if not getattr(self.config, "language_only", False):
+ # Create vision tower and mm projector (needed for encoder_only and normal mode)
+ self.vision_tower = MoonViT3dPretrainedModel(
+ config.vision_config, use_data_parallel=self.use_data_parallel
+ )
+ self.mm_projector = K2VLMultiModalProjector(config.vision_config)
+
+ if not getattr(self.config, "encoder_only", False):
+ # Create language model (needed for language_only and normal mode)
+ self.language_model = DeepseekV3ForCausalLM(
+ config.text_config, quant_config
+ )
- # Ensure that the dtype of the vision_tower and mm_projector matches that of the language_model.
- # This solves the dtype mismatch issue when using device_map="auto" and torch_dtype.
- if hasattr(self.language_model, "dtype"):
- target_dtype = self.language_model.dtype
- self.vision_tower = self.vision_tower.to(dtype=target_dtype)
- self.mm_projector = self.mm_projector.to(dtype=target_dtype)
+        # Ensure dtype consistency between vision and language components.
+        # language_model is only created when encoder_only is off, and vision_tower
+        # only when language_only is off, so both must be checked for existence
+        # before the language model's dtype is read.
+        language_model = getattr(self, "language_model", None)
+        if hasattr(self, "vision_tower") and hasattr(language_model, "dtype"):
+            target_dtype = language_model.dtype
+            self.vision_tower = self.vision_tower.to(dtype=target_dtype)
+            self.mm_projector = self.mm_projector.to(dtype=target_dtype)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
- pixel_values = torch.cat([item.feature for item in items], dim=0).type(
- self.vision_tower.dtype
+ pixel_values = torch.cat([item.feature for item in items], dim=0).to(
+ dtype=self.vision_tower.dtype, device=self.vision_tower.device
)
grid_thws = torch.concat([item.grid_thws for item in items], dim=0).to(
self.vision_tower.device
@@ -735,41 +740,59 @@ class KimiK25ForConditionalGeneration(nn.Module):
return hidden_states
+    def set_eagle3_layers_to_capture(self, layer_ids=None):
+        """Forward ``layer_ids`` to ``self.language_model.set_eagle3_layers_to_capture``."""
+        self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        """Return whatever ``self.language_model.get_embed_and_head()`` returns."""
+        return self.language_model.get_embed_and_head()
+
+    def set_embed_and_head(self, embed, head):
+        """Forward ``embed`` and ``head`` to ``self.language_model.set_embed_and_head``."""
+        self.language_model.set_embed_and_head(embed, head)
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model, separating vision and language weights"""
mapper = getattr(self, "hf_to_sglang_mapper", None)
if mapper is not None:
weights = mapper.apply(weights)
+ is_encoder_only = getattr(self.config, "encoder_only", False)
+ is_language_only = getattr(self.config, "language_only", False)
+
# Separate vision tower weights and language model weights
vision_weights = []
language_weights = []
for name, loaded_weight in weights:
if "vision_tower" in name or "mm_projector" in name:
+ # Skip vision weights in language_only mode
+ if is_language_only:
+ continue
name = name.replace(r"wqkv.", r"attn.qkv_proj.")
name = name.replace(r"wo.", r"attn.proj.")
name = name.replace("mm_projector.proj.0", "mm_projector.linear_1")
name = name.replace("mm_projector.proj.2", "mm_projector.linear_2")
vision_weights.append((name, loaded_weight))
else:
+ # Skip language weights in encoder_only mode
+ if is_encoder_only:
+ continue
name = name.replace("language_model.", "")
# All other weights go to language model
language_weights.append((name, loaded_weight))
# Load vision tower weights
- vision_state_dict = dict(vision_weights)
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in vision_state_dict.items():
- if name not in params_dict:
- raise ValueError(f"Weight {name} not found in params_dict")
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- # loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
- weight_loader(param, loaded_weight)
+ if not is_language_only:
+ vision_state_dict = dict(vision_weights)
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ for name, loaded_weight in vision_state_dict.items():
+ if name not in params_dict:
+ raise ValueError(f"Weight {name} not found in params_dict")
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
# Load language model weights
- if language_weights:
+ if not is_encoder_only and language_weights:
self.language_model.load_weights(language_weights)
@@ -85,6 +85,11 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)
+ if embeds.dtype != hidden_states.dtype:
+ raise RuntimeError(
+ f"Eagle3 dtype mismatch: embeds.dtype={embeds.dtype}, "
+ f"hidden_states.dtype={hidden_states.dtype}"
+ )
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention
hidden_states = self.self_attn(
@@ -160,6 +165,11 @@ class LlamaModel(nn.Module):
hidden_states = forward_batch.spec_info.hidden_states
if hidden_states.shape[-1] != embeds.shape[-1]:
+ if hidden_states.dtype != self.fc.weight.dtype:
+ raise RuntimeError(
+ f"Eagle3 dtype mismatch: hidden_states.dtype={hidden_states.dtype}, "
+ f"fc.weight.dtype={self.fc.weight.dtype}"
+ )
hidden_states = self.fc(hidden_states)
# idle batch
@@ -372,6 +372,7 @@ class Qwen3_5LinearDecoderLayer(nn.Module):
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
+ is_last_layer=(layer_id == config.num_hidden_layers - 1),
)
def forward(
@@ -400,11 +401,24 @@ class Qwen3_5LinearDecoderLayer(nn.Module):
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
- hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
- hidden_states, residual = self.layer_communicator.postprocess_layer(
- hidden_states, residual, forward_batch
+ should_allreduce_fusion = (
+ self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+ forward_batch
+ )
)
+ if isinstance(self.mlp, Qwen2MoeSparseMoeBlock):
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+ else:
+ hidden_states = self.mlp(
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
+ )
+ if should_allreduce_fusion:
+ hidden_states._sglang_needs_allreduce_fusion = True
+ else:
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states, residual, forward_batch
+ )
return hidden_states, residual
@@ -549,6 +563,7 @@ class Qwen3_5AttentionDecoderLayer(nn.Module):
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
+ is_last_layer=(layer_id == config.num_hidden_layers - 1),
)
self.alt_stream = alt_stream
@@ -633,11 +648,24 @@ class Qwen3_5AttentionDecoderLayer(nn.Module):
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
- hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
- hidden_states, residual = self.layer_communicator.postprocess_layer(
- hidden_states, residual, forward_batch
+ should_allreduce_fusion = (
+ self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+ forward_batch
+ )
)
+ if isinstance(self.mlp, Qwen2MoeSparseMoeBlock):
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+ else:
+ hidden_states = self.mlp(
+ hidden_states, should_allreduce_fusion, use_reduce_scatter
+ )
+ if should_allreduce_fusion:
+ hidden_states._sglang_needs_allreduce_fusion = True
+ else:
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states, residual, forward_batch
+ )
return hidden_states, residual
@@ -711,14 +711,19 @@ class Qwen3LLMModel(Qwen3Model):
hidden_states + residual if residual is not None else hidden_states
)
+ deepstack_embeds = None
+ if input_deepstack_embeds is not None:
+ prev_layer_idx = layer_idx - 1
+ if prev_layer_idx in self.deepstack_embed_to_decoder_layer:
+ sep = self.hidden_size * prev_layer_idx
+ deepstack_embeds = input_deepstack_embeds[
+ :, sep : sep + self.hidden_size
+ ]
+
# SGLang applies residual at the START of the next layer, not at the END like HuggingFace.
# See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549
# To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack
# The order matters because addition with different tensors is not associative in practice.
- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
- deepstack_embeds = self.get_deepstack_embeds(
- layer_idx - 1, input_deepstack_embeds
- )
hidden_states, residual = layer(
positions,
hidden_states,
@@ -1,6 +1,9 @@
from typing import List, Union
+import torch
+
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
from sglang.srt.models.glm4v_moe import Glm4vMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
@@ -45,6 +48,8 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id
self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id
self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id
+ self.IM_START_TOKEN_ID = self.IMAGE_START_TOKEN_ID
+ self.IM_END_TOKEN_ID = self.IMAGE_END_TOKEN_ID
# Vision config
self.IMAGE_FACTOR = 28
@@ -59,6 +64,36 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
video_token_id=self.IM_TOKEN_ID,
).build(_processor)
+    def get_mm_data(self, prompt, embeddings, img_grid_thw):
+        """Build the multimodal payload for a prompt with precomputed image embeddings.
+
+        Token ids and placeholder offsets come from ``build_input_ids``; a single
+        image item wraps ``embeddings``; mrope positions are computed by
+        ``MRotaryEmbedding.get_rope_index_glm4v`` on the ids with a leading batch
+        dimension added, without video grids or an attention mask.
+
+        Returns:
+            dict with keys ``input_ids``, ``mm_items``, ``im_start_id``,
+            ``im_end_id``, ``im_token_id``, ``mrope_positions`` and
+            ``mrope_position_delta``.
+        """
+        token_ids, placeholder_offsets = self.build_input_ids(prompt, img_grid_thw)
+        image_item = MultimodalDataItem(
+            modality=Modality.IMAGE,
+            offsets=placeholder_offsets,
+            precomputed_embeddings=embeddings,
+        )
+
+        # The rope helper receives the ids with a batch dimension in front.
+        batched_ids = torch.tensor(token_ids).unsqueeze(0)
+        rope_positions, rope_delta = MRotaryEmbedding.get_rope_index_glm4v(
+            input_ids=batched_ids,
+            hf_config=self.hf_config,
+            image_grid_thw=img_grid_thw,
+            video_grid_thw=None,
+            attention_mask=None,
+        )
+
+        return {
+            "input_ids": token_ids,
+            "mm_items": [image_item],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.IM_TOKEN_ID,
+            "mrope_positions": rope_positions.squeeze(1),
+            "mrope_position_delta": rope_delta,
+        }
+
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -25,6 +25,18 @@ class KimiK2_5VLImageProcessor(SGLangBaseProcessor):
image_token_id=hf_config.media_placeholder_token_id,
image_token_regex=re.compile(r"(?:<\|media_pad\|>)+"),
).build(_processor)
+ # Required by base class get_mm_data / build_input_ids for EPD mode
+ self.IM_TOKEN_ID = hf_config.media_placeholder_token_id
+ self.IM_START_TOKEN_ID = None
+ self.IM_END_TOKEN_ID = None
+ merge_kernel = getattr(hf_config.vision_config, "merge_kernel_size", [2, 2])
+ self._spatial_merge_size = (
+ merge_kernel[0] if isinstance(merge_kernel, (list, tuple)) else merge_kernel
+ )
+
+    @property
+    def spatial_merge_size(self):
+        """Return the value stored in ``_spatial_merge_size``."""
+        return self._spatial_merge_size
async def process_mm_data_async(
self,
@@ -317,7 +317,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
**kwargs,
):
entry_time = time.perf_counter()
- base_output = self.load_mm_data(
+ base_output = self.legacy_load_mm_data(
prompt=input_text,
image_data=image_data,
video_data=request_obj.video_data,
@@ -580,6 +580,7 @@ class ServerArgs:
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
disable_cuda_graph: bool = False
+ disable_draft_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
@@ -635,6 +636,7 @@ class ServerArgs:
# Context parallelism used in the long sequence prefill phase of DeepSeek v3.2
enable_nsa_prefill_context_parallel: bool = False
nsa_prefill_cp_mode: str = "round-robin-split"
+ disable_indexer_rope_neox_style: bool = False
enable_fused_qk_norm_rope: bool = False
enable_precise_embedding_interpolation: bool = False
@@ -2089,7 +2091,16 @@ class ServerArgs:
assert (
self.tp_size % (self.dp_size * self.attn_cp_size) == 0
), "tp_size must be divisible by dp_size * attn_cp_size"
- assert self.pp_size == 1, "PP is not supported with context parallelism"
+ if self.pp_size > 1:
+ assert (
+ self.disaggregation_mode == "prefill"
+ and self.enable_nsa_prefill_context_parallel
+ and self.nsa_prefill_cp_mode == "round-robin-split"
+ ), (
+ "PP with context parallelism is only supported for PD prefill "
+ "with --enable-nsa-prefill-context-parallel and "
+ "--nsa-prefill-cp-mode round-robin-split."
+ )
if self.moe_dp_size > 1:
# The tp_size is the world size, not the real tensor parallel size
@@ -4491,6 +4502,11 @@ class ServerArgs:
action="store_true",
help="Disable cuda graph.",
)
+ parser.add_argument(
+ "--disable-draft-cuda-graph",
+ action="store_true",
+ help="Disable cuda graph for draft model in speculative decoding.",
+ )
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
@@ -4781,6 +4797,12 @@ class ServerArgs:
help="Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: 'round-robin-split'(default), 'in-seq-split' "
"'round-robin-split' distributes tokens across ranks based on token_idx %% cp_size. It supports multi-batch prefill, fused MoE, and FP8 KV cache.",
)
+ parser.add_argument(
+ "--disable-indexer-rope-neox-style",
+ action="store_true",
+ help="Disable NSA indexer RoPE neox style (equivalent to INDEXER_ROPE_NEOX_STYLE=0). "
+ "If the environment variable INDEXER_ROPE_NEOX_STYLE is also set and conflicts, an error is raised.",
+ )
parser.add_argument(
"--enable-fused-qk-norm-rope",
action="store_true",
@@ -5636,6 +5658,54 @@ class PortArgs:
)
if not server_args.enable_dp_attention:
+            # In multi-node prefill PD with PP/CP, use TCP transport for tokenizer<->scheduler
+            # communication; IPC occasionally stalls in this topology.
+ if (
+ server_args.nnodes > 1
+ and server_args.disaggregation_mode == "prefill"
+ and server_args.dist_init_addr is not None
+ ):
+ if server_args.dist_init_addr.startswith("["): # ipv6 address
+ port_num, host = configure_ipv6(server_args.dist_init_addr)
+ dist_init_addr = (host, str(port_num))
+ else:
+ dist_init_addr = server_args.dist_init_addr.split(":")
+
+ assert (
+ len(dist_init_addr) == 2
+ ), "please provide --dist-init-addr as host:port of head node"
+
+ dist_init_host, dist_init_port = dist_init_addr
+ dist_init_port = int(dist_init_port)
+ port_base = dist_init_port + ZMQ_TCP_PORT_DELTA
+ detokenizer_port = port_base + 1
+ rpc_port = port_base + 2
+ metrics_port = port_base + 3
+ scheduler_input_port = port_base + 4
+
+                # Check every port this branch is about to hand out. On failure, log
+                # all of them (including rpc_port and metrics_port, which are checked
+                # here too) so the conflicting one can be identified, then re-raise.
+                try:
+                    wait_port_available(dist_init_port, "dist_init_port")
+                    wait_port_available(port_base, "port_base")
+                    wait_port_available(detokenizer_port, "detokenizer_port")
+                    wait_port_available(nccl_port, "nccl_port")
+                    wait_port_available(rpc_port, "rpc_port")
+                    wait_port_available(metrics_port, "metrics_port")
+                    wait_port_available(scheduler_input_port, "scheduler_input_port")
+                except ValueError:
+                    logger.exception(
+                        f"Port is already in use. {dist_init_port=} {port_base=} {detokenizer_port=} {nccl_port=} {rpc_port=} {metrics_port=} {scheduler_input_port=}"
+                    )
+                    raise
+
+ return PortArgs(
+ tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+ scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+ detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
+ nccl_port=nccl_port,
+ rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
+ metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_port}",
+ tokenizer_worker_ipc_name=tokenizer_worker_ipc_name,
+ )
# Normal case, use IPC within a single node
return PortArgs(
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner:
self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_()
self.positions.zero_()
-
+ self.topk_p.zero_()
+ self.topk_index.zero_()
+ self.hidden_states.zero_()
+ self.req_pool_indices.zero_()
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
@@ -350,8 +353,12 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1))
+ self.topk_index[:raw_bs].copy_(
+ forward_batch.spec_info.topk_index.clamp(
+ 0, self.model_runner.model_config.vocab_size - 1
+ )
+ )
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -337,7 +337,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
sampling_info.top_ks, self.draft_token_num, dim=0
),
) # (bs * draft_token_num, vocab_size)
- if not torch.all(sampling_info.top_ps == 1.0):
+ if sampling_info.need_top_p_sampling:
target_probs = top_p_renorm_prob(
target_probs,
torch.repeat_interleave(
@@ -774,6 +774,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
+ if self.accept_length is not None:
+ self.accept_length = self.accept_length[: len(new_indices)]
+ if self.accept_length_cpu is not None:
+ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
@@ -805,6 +809,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+        # Keep accept_length in step with the tensors merged above. When only one
+        # side carries accept_length, the other side's rows are filled with zeros.
+        # NOTE(review): padding length is taken from verified_id's first dim, which
+        # assumes accept_length has one entry per verified_id row - confirm.
+        if self.accept_length is not None and spec_info.accept_length is not None:
+            self.accept_length = torch.cat(
+                [self.accept_length, spec_info.accept_length]
+            )
+            self.accept_length_cpu = self.accept_length.tolist()
+        elif self.accept_length is not None:
+            zeros = torch.zeros(
+                [spec_info.verified_id.shape[0]],
+                dtype=self.accept_length.dtype,
+                device=self.accept_length.device,
+            )
+            self.accept_length = torch.cat([self.accept_length, zeros])
+            self.accept_length_cpu = self.accept_length.tolist()
+        elif spec_info.accept_length is not None:
+            # self.accept_length is None in this branch, so dtype and device must
+            # be taken from spec_info.accept_length. self.verified_id has already
+            # been concatenated with spec_info.verified_id, so the number of rows
+            # that belonged to self before the merge is the difference.
+            zeros = torch.zeros(
+                [self.verified_id.shape[0] - spec_info.verified_id.shape[0]],
+                dtype=spec_info.accept_length.dtype,
+                device=spec_info.accept_length.device,
+            )
+            self.accept_length = torch.cat([zeros, spec_info.accept_length])
+            self.accept_length_cpu = self.accept_length.tolist()
@dataclass
@@ -234,7 +234,10 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
- if self.server_args.disable_cuda_graph:
+ if (
+ self.server_args.disable_cuda_graph
+ or self.server_args.disable_draft_cuda_graph
+ ):
return
Device2DraftCudaGraphRunner = {
@@ -2359,6 +2359,8 @@ class SafeUnpickler(pickle.Unpickler):
"sglang.srt.model_executor.model_runner.",
"sglang.srt.layers.",
"sglang.srt.utils.",
+ # --- slime ---
+ "slime.",
}
DENY_CLASSES = {
@@ -69,6 +69,9 @@ def _check_tensors(
actual_should_compare,
actual,
) in zip(expect_tensors, actual_tensors, strict=True):
+ if ".cos_sin_cache" in expect_name:
+ # skip cos/sin cache which is deterministic from shape and dtype and may have different shapes due to different implementations.
+ continue
assert expect_name == actual_name, f"{expect_name=} {actual_name=}"
assert (
expect_should_compare == actual_should_compare