@@ -268,6 +268,12 @@ class ModelConfig:
):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+ if (
+ is_draft_model
+ and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM"
+ ):
+ self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
@@ -315,6 +315,13 @@ class DecodePreallocQueue:
)
return kv_manager
+    def release_memory_occupation(self):
+        """Call ``close()`` on the current KV manager if it defines one.
+
+        Managers without a ``close`` attribute are left untouched.
+        """
+        manager = self.kv_manager
+        if not hasattr(manager, "close"):
+            return
+        manager.close()
+
+ def resume_memory_occupation(self):
+ """Replace ``self.kv_manager`` with a new manager built by ``_init_kv_manager()``."""
+ self.kv_manager = self._init_kv_manager()
+
def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
if self._check_if_req_exceed_kv_capacity(req):
@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager):
f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
)
+    def close(self):
+        """Deregister the KV, auxiliary and state buffer pointers from ``self.engine``.
+
+        Each pointer list on ``self.kv_args`` is handed to
+        ``engine.batch_deregister`` only when it is non-empty.
+        """
+        # Same order as registration groups: KV data, auxiliary data, state/extra pool data.
+        for attr in ("kv_data_ptrs", "aux_data_ptrs", "state_data_ptrs"):
+            ptrs = getattr(self.kv_args, attr)
+            if ptrs:
+                self.engine.batch_deregister(ptrs)
+
class MooncakeKVSender(CommonKVSender):
@@ -306,6 +306,13 @@ class PrefillBootstrapQueue:
else:
return bootstrapped_reqs, failed_reqs
+    def release_memory_occupation(self):
+        """Call ``close()`` on the current KV manager if it defines one.
+
+        Managers without a ``close`` attribute are left untouched.
+        """
+        manager = self.kv_manager
+        if not hasattr(manager, "close"):
+            return
+        manager.close()
+
+ def resume_memory_occupation(self):
+ """Replace ``self.kv_manager`` with a new manager built by ``_init_kv_manager()``."""
+ self.kv_manager = self._init_kv_manager()
+
class SchedulerDisaggregationPrefillMixin:
"""
@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
- return get_tp_group().rank_in_group
+ try:
+ return get_tp_group().rank_in_group
+ # NOTE(review): every failure is mapped to rank 0 — presumably meant for the case
+ # where the TP group is not initialized yet (confirm). This also hides unrelated
+ # errors; consider catching only the exception get_tp_group() actually raises.
+ except Exception:
+ return 0
def get_pipeline_model_parallel_world_size():
@@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
MultimodalDataInputFormat,
+ PostProcessWeightsReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
@@ -593,6 +594,24 @@ class Engine(EngineBase):
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
+    def post_process_weights(
+        self,
+        restore_weights_before_load: bool = False,
+        post_process_quantization: bool = False,
+    ):
+        """Run optional post-processing on updated weights (e.g., Marlin conversion).
+
+        Intended to be called once a weight update has finished.
+
+        Args:
+            restore_weights_before_load: Forwarded to ``PostProcessWeightsReqInput``.
+            post_process_quantization: Forwarded to ``PostProcessWeightsReqInput``.
+
+        Returns:
+            The result of ``tokenizer_manager.post_process_weights`` once the
+            event loop has run it to completion.
+        """
+        request = PostProcessWeightsReqInput(
+            restore_weights_before_load=restore_weights_before_load,
+            post_process_quantization=post_process_quantization,
+        )
+        pending = self.tokenizer_manager.post_process_weights(request, None)
+        return self.loop.run_until_complete(pending)
+
def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
@@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+@app.post("/post_process_weights")
+async def post_process_weights(req: PostProcessWeightsReqInput, request: Request):
+    """
+    Optional post-processing for updated weights (e.g., Marlin conversion).
+    This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`.
+    """
+    # The tokenizer manager reports a (success, message) pair.
+    manager = _global_state.tokenizer_manager
+    success, message = await manager.post_process_weights(req, request)
+
+    # A failed run is reported to the client as HTTP 400.
+    status = 200 if success else HTTPStatus.BAD_REQUEST
+    return ORJSONResponse(
+        {"success": success, "message": message}, status_code=status
+    )
+
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
@@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import os
import torch
from einops import rearrange
@@ -178,7 +179,7 @@ class Indexer(MultiPlatformOp):
max_position=max_position_embeddings,
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
- is_neox_style=True,
+ is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False,
device=get_global_server_args().device,
)
self.block_size = block_size
@@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp):
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x.float())
+ # When dim 1 of the projected weights is smaller than 32, repeat each entry so
+ # that the dimension becomes exactly 32; its size must therefore divide 32.
+ # NOTE(review): presumably a downstream kernel needs at least 32 heads — confirm.
+ if weights.shape[1] < 32:
+ assert 32 % weights.shape[1] == 0
+ weights = weights.repeat_interleave(32 // weights.shape[1], dim=1)
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
@@ -837,6 +841,9 @@ class Indexer(MultiPlatformOp):
query, key = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
)
+ if query.shape[1] < 32:
+ assert 32 % query.shape[1] == 0
+ query = query.repeat_interleave(32//query.shape[1], dim=1)
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
cast_x_before_out_mul: bool = False,
- fp32_residual: bool = False,
- weight_dtype: Optional = None,
- override_orig_dtype: Optional = None,
+ fp32_residual: bool = True,
) -> None:
super().__init__()
self.cast_x_before_out_mul = cast_x_before_out_mul
self.fp32_residual = fp32_residual
- self.override_orig_dtype = override_orig_dtype
- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
+ self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
self.variance_size_override = (
@@ -193,10 +190,22 @@ class RMSNorm(MultiPlatformOp):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
- orig_dtype = self.override_orig_dtype or x.dtype
+ orig_dtype = x.dtype
post_residual_addition = kwargs.get("post_residual_addition")
+
+ if residual is not None and not self.fp32_residual:
+ x = (
+ x
+ + residual
+ + (
+ post_residual_addition
+ if post_residual_addition is not None
+ else 0.0
+ )
+ )
+ residual = x.clone()
x = x.to(torch.float32)
- if residual is not None:
+ if residual is not None and self.fp32_residual:
x = (
x
+ residual.to(torch.float32)
@@ -206,10 +215,7 @@ class RMSNorm(MultiPlatformOp):
else 0.0
)
)
- if self.fp32_residual:
- residual = x.clone()
- else:
- residual = x.to(orig_dtype)
+ residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module):
None, # bias
True, # is_vnni
)
- elif get_global_server_args().rl_on_policy_target is not None:
- # Due to tie-weight, we may not be able to change lm_head's weight dtype
- logits = torch.matmul(
- hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
- )
else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
@@ -14,6 +14,7 @@ import torch.nn.functional as F
import triton.language as tl
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -573,7 +574,10 @@ def fused_experts_impl(
).squeeze(dim=1)
else:
# According to micro benchmark results, torch.compile can get better performance for small token.
- if tokens_in_chunk <= 32:
+ if (
+ not get_global_server_args().enable_deterministic_inference
+ and tokens_in_chunk <= 32
+ ):
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
@@ -1,5 +1,6 @@
import logging
from abc import ABC
+from contextlib import contextmanager
from typing import Optional
import numpy as np
@@ -8,13 +9,18 @@ import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import (
+ attn_tp_all_gather_into_tensor,
get_attention_dp_rank,
+ get_attention_tp_size,
get_dp_local_info,
is_dp_attention_enabled,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
+from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+)
logger = logging.getLogger(__name__)
@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
device=device,
)
+ if get_moe_a2a_backend().is_deepep():
+ attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1
+ self.gather_buffer = torch.empty(
+ (
+ self.device_cache.buffer.shape[0] * attn_tp_size,
+ self.device_cache.buffer.shape[2],
+ ),
+ dtype=torch.int32,
+ device=device,
+ )
+
def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
- if is_dp_attention_enabled():
+ # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer
+ # contains data from all DP ranks. We should not slice by DP rank in this case.
+ if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# handle with cuda graph padding
if can_run_graph:
@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
].cpu()
def capture(self, layer_id: int, topk_ids: torch.Tensor):
+ # With the DeepEP backend, gather the local top-k ids into a slice of the
+ # preallocated gather buffer and record the gathered ids instead.
+ # NOTE(review): the buffer allocation multiplies by the attention-TP size only
+ # when DP attention is enabled, whereas this slice always does — confirm the
+ # two agree when DeepEP is used with DP attention disabled.
+ if get_moe_a2a_backend().is_deepep():
+ local_topk_ids = topk_ids
+ topk_ids = self.gather_buffer[
+ : local_topk_ids.size(0) * get_attention_tp_size()
+ ]
+ attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids)
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
def get_routed_experts(
@@ -1016,13 +1016,37 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.a2_scale = None
layer.marlin_state = GPTQMarlinState.REPACK
+ if not hasattr(layer, "_original_shapes"):
+ layer._original_shapes = {}
+
+ # Force record: these are the target GPTQ shapes for rollback.
+ layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
+ layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
+
+ # Also record the shapes of the scales.
+ layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
+ layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ # Skip if the layer is already converted to Marlin format to prevent double-packing.
+ if getattr(layer, "is_marlin_converted", False):
+ return
+
+ if not hasattr(layer, "_original_shapes"):
+ layer._original_shapes = {}
def replace_tensor(name, new_t):
+ target_attr = getattr(layer, name)
+
+ # Only save if the key doesn't exist to prevent overwriting with Marlin shapes.
+ if name not in layer._original_shapes:
+ # This is a safety check; `create_weights` usually handles this already.
+ layer._original_shapes[name] = tuple(target_attr.shape)
+
# It is important to use resize_() here since it ensures
# the same buffer is reused
- getattr(layer, name).resize_(new_t.shape)
- getattr(layer, name).copy_(new_t)
+ target_attr.resize_(new_t.shape)
+ target_attr.copy_(new_t)
del new_t
num_experts = layer.w13_weight_g_idx.shape[0]
@@ -1078,7 +1102,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[2],
self.num_bits,
)
- replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+ replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
@@ -1086,7 +1110,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed.shape[2],
self.num_bits,
)
- replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
+ replace_tensor("w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
@@ -1094,7 +1118,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale.shape[2],
self.group_size,
)
- replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
+ replace_tensor("w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
@@ -1103,7 +1127,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale.shape[2],
self.group_size,
)
- replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+ replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+ layer.is_marlin_converted = True
+
+    def restore_weights_before_loading(self, layer: torch.nn.Module):
+        """Resize recorded parameters back to their original shapes (e.g., GPTQ format).
+
+        Does nothing when ``layer`` carries no ``_original_shapes`` record.
+        Otherwise every recorded attribute that still exists and whose shape
+        differs from the recorded one is resized in place, and the layer is
+        marked as no longer Marlin-converted.
+        """
+        if hasattr(layer, "_original_shapes"):
+            for name, orig_shape in layer._original_shapes.items():
+                param = getattr(layer, name, None)
+                # Skip attributes that are absent or already have the recorded shape.
+                if param is None or param.shape == orig_shape:
+                    continue
+                param.resize_(orig_shape)
+
+            layer.is_marlin_converted = False
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp):
if get_global_server_args().rl_on_policy_target is not None:
self._forward_method = self.forward_native
- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
- self._apply_rotary_emb_wrapped
- )
+
self.position_cos, self.position_sin = None, None
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert (
+ fused_set_kv_buffer_arg is None
+ ), "fused_set_kv_buffer_arg is not supported for npu implementation"
# TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096
assert (
fused_set_kv_buffer_arg is None
@@ -108,16 +108,11 @@ class Sampler(nn.Module):
if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
- if get_global_server_args().rl_on_policy_target is not None:
- logits_div_temperature = (
- logits.bfloat16().div(sampling_info.temperatures).bfloat16()
- )
- logprobs_via_logsoftmax_kernel = torch.log_softmax(
- logits_div_temperature, dim=-1
- )
-
# Post process logits
logits.div_(sampling_info.temperatures)
+ if get_global_server_args().rl_on_policy_target is not None:
+ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1)
+
# For ascend backend, softmax is not needed before sampling
if not get_global_server_args().sampling_backend == "ascend" or (
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
@@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq):
success: bool
message: str
+@dataclass
+class PostProcessWeightsReqInput(BaseReq):
+ """Request selecting which weight post-processing steps to run; both flags default to off."""
+ # Whether to restore weights before loading new weights
+ restore_weights_before_load: bool = False
+ # Whether to enable quantization post-processing
+ post_process_quantization: bool = False
+
+
+@dataclass
+class PostProcessWeightsReqOutput(BaseReq):
+ """Reply to a post-process-weights request: a success flag plus a message."""
+ success: bool
+ message: str
+
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
@@ -2186,7 +2186,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def __str__(self):
+ """Return a one-line summary: forward mode, request count and out_cache_loc."""
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
- f"#req={(len(self.reqs))})"
+ f"#req={(len(self.reqs))}, "
+ f"#out_cache_loc={self.out_cache_loc})"
)
@@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
PauseGenerationReqInput,
+ PostProcessWeightsReqInput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
@@ -1060,6 +1061,7 @@ class Scheduler(
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
+ (PostProcessWeightsReqInput, self.post_process_weights),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
+
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
@@ -1070,7 +1071,7 @@ class SchedulerOutputProcessorMixin:
req.log_time_stats()
# Send to detokenizer
- if reqs or is_idle_batch:
+ if rids or is_idle_batch:
if self.model_config.is_multimodal_gen:
return
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+import os
import traceback
from typing import TYPE_CHECKING, Tuple
@@ -12,6 +13,9 @@ from sglang.srt.constants import (
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.managers.io_struct import (
CheckWeightsReqInput,
CheckWeightsReqOutput,
@@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
@@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromIPCReqOutput(success, message)
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Optional post-processing for updated weights (e.g., Marlin conversion)."""
+ # Delegate to the TP worker and wrap its (success, message) pair in the reply type.
+ success, message = self.tp_worker.post_process_weights(recv_req)
+ return PostProcessWeightsReqOutput(success, message)
+
def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)
@@ -137,6 +148,13 @@ class SchedulerUpdateWeightsMixin:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.release_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.release_memory_occupation()
+
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.model_runner.model
@@ -177,6 +195,13 @@ class SchedulerUpdateWeightsMixin:
if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ if hasattr(self, "disagg_decode_prealloc_queue"):
+ self.disagg_decode_prealloc_queue.resume_memory_occupation()
+ elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+ if hasattr(self, "disagg_prefill_bootstrap_queue"):
+ self.disagg_prefill_bootstrap_queue.resume_memory_occupation()
+
return ResumeMemoryOccupationReqOutput()
def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
@@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqOutput,
LoRAUpdateOutput,
OpenSessionReqInput,
+ PostProcessWeightsReqInput,
+ PostProcessWeightsReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
@@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin:
self.update_weights_from_ipc_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.post_process_weights_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin:
UpdateWeightsFromIPCReqOutput,
self.update_weights_from_ipc_communicator.handle_recv,
),
+ (
+ PostProcessWeightsReqOutput,
+ self.post_process_weights_communicator.handle_recv,
+ ),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
@@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin:
return success, message
+    async def post_process_weights(
+        self: TokenizerManager,
+        obj: PostProcessWeightsReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion).
+
+        The request is sent through the post-process-weights communicator while
+        the model-update writer lock is held, and the collected replies are
+        merged into a single ``(success, message)`` pair.
+        """
+        self.auto_create_handle_loop()
+        async with self.model_update_lock.writer_lock:
+            replies = await self.post_process_weights_communicator(obj)
+            merged = _Communicator.merge_results(replies)
+        return merged
+
async def init_weights_send_group_for_remote_instance(
self,
obj: InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
+ PostProcessWeightsReqInput,
SendWeightsToRemoteInstanceReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
@@ -175,6 +176,11 @@ class BaseTpWorker(ABC):
success, message = self.model_runner.update_weights_from_ipc(recv_req)
return success, message
+ def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+ """Perform optional post-processing on the updated model weights (e.g., Marlin conversion)."""
+ # Delegate to the model runner; its (success, message) pair is returned unchanged.
+ success, message = self.model_runner.post_process_weights(recv_req)
+ return success, message
+
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
@@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
- else nullcontext()
+ else nullcontext(),
+ self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE),
):
self.index_k_with_scale_buffer = [
torch.zeros(
@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin):
)
# Init routed experts capturer
- self.init_routed_experts_capturer()
+ if not self.is_draft_worker:
+ self.init_routed_experts_capturer()
if self.device == "cuda":
self.init_cublas()
@@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin):
output.expert_distribution_metrics = recorder_outputs.get("metrics")
# Copy cached routing experts' buffers back to CPU cache
- get_global_experts_capturer().on_forward_end(
- forward_batch=forward_batch,
- can_run_graph=output.can_run_graph,
- cuda_graph_batch=getattr(self.graph_runner, "bs", None),
- )
+ if not self.is_draft_worker:
+ # In speculative decoding, num_tokens_per_bs > 1, so we need to pass
+ # the actual number of tokens per dp rank in cuda graph, not batch size.
+ cuda_graph_num_tokens = None
+ if getattr(self.graph_runner, "bs", None):
+ cuda_graph_num_tokens = (
+ self.graph_runner.bs * self.graph_runner.num_tokens_per_bs
+ )
+ get_global_experts_capturer().on_forward_end(
+ forward_batch=forward_batch,
+ can_run_graph=output.can_run_graph,
+ cuda_graph_batch=cuda_graph_num_tokens,
+ )
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end()
@@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin):
logger.error(f"IPC weight update failed: {e}")
return False, str(e)
+    def post_process_weights(self, recv_req):
+        """Run optional post-processing hooks on the model's quantized modules.
+
+        For each requested step, every module whose ``quant_method`` provides the
+        matching hook is processed inside ``device_loading_context`` on the
+        current CUDA device.
+
+        Args:
+            recv_req: Request carrying the ``restore_weights_before_load`` and
+                ``post_process_quantization`` flags.
+
+        Returns:
+            ``(True, "Success")`` when every requested hook ran, otherwise
+            ``(False, <error text>)`` so the caller can report the failure.
+        """
+        from sglang.srt.model_loader.loader import device_loading_context
+
+        # Hooks run in the order the request flags are checked:
+        # restore first, then quantization post-processing (e.g., Marlin repacking).
+        hook_names = []
+        if recv_req.restore_weights_before_load:
+            hook_names.append("restore_weights_before_loading")
+        if recv_req.post_process_quantization:
+            hook_names.append("process_weights_after_loading")
+
+        try:
+            target_device = torch.device("cuda", torch.cuda.current_device())
+            for hook_name in hook_names:
+                for _, module in self.model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    # Modules without a quant method, or whose quant method lacks
+                    # this hook, are skipped.
+                    hook = getattr(quant_method, hook_name, None)
+                    if hook is None:
+                        continue
+                    with device_loading_context(module, target_device):
+                        hook(module)
+        except Exception as e:
+            # Report the failure through the (success, message) contract instead of
+            # letting it propagate, matching the other weight-update handlers.
+            logger.error(f"Post-processing weights failed: {e}")
+            return False, str(e)
+
+        return True, "Success"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
@@ -2704,7 +2704,11 @@ class DeepseekV2AttentionMLA(nn.Module):
):
k = k_nope.new_empty(*k_shape)
concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
- elif _is_cuda:
+ elif _is_cuda and all(
+ # (i.bit_count() == 1) == (is_power_of_two(i))
+ i.bit_count() == 1
+ for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1])
+ ):
# fa3 mha support fp8 inputs
if (
self.current_attention_backend == "fa3"
@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x):
- if get_global_server_args().rl_on_policy_target is not None:
- x = x.bfloat16()
-
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module):
quant_config=quant_config,
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
- params_dtype=(
- torch.float32
- if get_global_server_args().rl_on_policy_target is not None
- else None
- ),
)
else:
self.embed_tokens = PPMissingLayer()
@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module):
if self.pp_group.is_last_rank:
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
- override_orig_dtype=torch.float32,
- fp32_residual=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module):
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+ )
else:
self.norm = PPMissingLayer(return_tuple=True)
@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module):
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module):
norm_kwargs = (
dict(
- weight_dtype=torch.float32,
cast_x_before_out_mul=True,
- override_orig_dtype=torch.float32,
- fp32_residual=True,
+ fp32_residual=False,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
@@ -22,6 +22,7 @@ import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar
import torch
+import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_grouped_topk=False,
layer_id=layer_id,
)
+ self.top_k = config.num_experts_per_tok
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
- topk_output = self.topk(hidden_states, router_logits)
+
+ if get_global_server_args().rl_on_policy_target is not None:
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(
+ routing_weights, self.top_k, dim=-1
+ )
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
+ topk_output = StandardTopKOutput(
+ topk_weights=routing_weights,
+ topk_ids=selected_experts,
+ router_logits=router_logits,
+ )
+ else:
+ topk_output = self.topk(hidden_states, router_logits)
+
final_hidden_states = self.experts(hidden_states, topk_output)
if (
self.tp_size > 1
@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module):
)
self.compatible_with_fused_kv_buffer = (
False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
- )
+ ) and (get_global_server_args().rl_on_policy_target is None)
self.compatible_with_fused_qk_norm_rope = (
not isinstance(self.rotary_emb, MRotaryEmbedding)
) and self.head_dim in (64, 128, 256)
self.use_fused_qk_norm_rope = (
get_global_server_args().enable_fused_qk_norm_rope
and self.compatible_with_fused_qk_norm_rope
+ and (get_global_server_args().rl_on_policy_target is None)
)
self._used_fused_qk_norm_rope_last_call = False
@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module):
prefix=add_prefix("attn", prefix),
)
- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
self.alt_stream = alt_stream
def op_prepare(self, state):
@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ norm_kwargs = (
+ dict(
+ cast_x_before_out_mul=True,
+ fp32_residual=False,
+ )
+ if get_global_server_args().rl_on_policy_target is not None
+ else {}
+ )
+ self.input_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+ )
self.post_attention_layernorm = RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
+ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
)
self.layer_communicator = LayerCommunicator(
@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin):
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw):
- patch_pos_embeds_permute = []
- m_size = self.spatial_merge_size
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+ num_grid_per_side = int(self.num_position_embeddings**0.5)
+ device = self.pos_embed.weight.device
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * num_grid_per_side
+ base_h_ceil = h_idxs_ceil * num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
- embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device)
- embeds = (
- self.pos_embed(embeds)
- .permute(1, 0)
- .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side)
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
+ weight_tensor = torch.tensor(
+ weight_list, dtype=self.pos_embed.weight.dtype, device=device
)
- for t, h, w in grid_thw:
- pos_embed = torch.nn.functional.interpolate(
- embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners
- )
- pos_embed = pos_embed.reshape(
- -1,
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
+ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split(
+ [h * w for h, w in zip(grid_hs, grid_ws)]
+ )
+
+ patch_pos_embeds_permute = []
+ merge_size = self.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (
+ pos_embed.view(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
+ .permute(0, 1, 3, 2, 4, 5)
+ .flatten(0, 4)
)
- pos_embed = pos_embed.permute(1, 3, 2, 4, 0)
- pos_embed = pos_embed.flatten(0, 3).repeat(t, 1)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
@@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model):
hidden_states + residual if residual is not None else hidden_states
)
+ deepstack_embeds = None
+ if input_deepstack_embeds is not None:
+ prev_layer_idx = layer_idx - 1
+ if prev_layer_idx in self.deepstack_embed_to_decoder_layer:
+ sep = self.hidden_size * prev_layer_idx
+ deepstack_embeds = input_deepstack_embeds[
+ :, sep : sep + self.hidden_size
+ ]
+
# SGLang applies residual at the START of the next layer, not at the END like HuggingFace.
# See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549
# To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack
# The order matters because addition with different tensors is not associative in practice.
- # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
- deepstack_embeds = self.get_deepstack_embeds(
- layer_idx - 1, input_deepstack_embeds
- )
hidden_states, residual = layer(
positions,
hidden_states,
@@ -527,6 +527,7 @@ class ServerArgs:
cuda_graph_max_bs: Optional[int] = None
cuda_graph_bs: Optional[List[int]] = None
disable_cuda_graph: bool = False
+ disable_draft_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
@@ -3980,6 +3981,11 @@ class ServerArgs:
action="store_true",
help="Disable cuda graph.",
)
+ parser.add_argument(
+ "--disable-draft-cuda-graph",
+ action="store_true",
+ help="Disable cuda graph for draft model in speculative decoding.",
+ )
parser.add_argument(
"--disable-cuda-graph-padding",
action="store_true",
@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner:
self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_()
self.positions.zero_()
-
+ self.topk_p.zero_()
+ self.topk_index.zero_()
+ self.hidden_states.zero_()
+ self.req_pool_indices.zero_()
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
- self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
- self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+ self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1))
+ self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1))
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
+ if self.accept_length is not None:
+ self.accept_length = self.accept_length[: len(new_indices)]
+ if self.accept_length_cpu is not None:
+ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+ if self.accept_length is not None and spec_info.accept_length is not None:
+ self.accept_length = torch.cat(
+ [self.accept_length, spec_info.accept_length]
+ )
+ self.accept_length_cpu = self.accept_length.tolist()
+ elif self.accept_length is not None:
+ zeros = torch.zeros(
+ [spec_info.verified_id.shape[0]],
+ dtype=self.accept_length.dtype,
+ device=self.accept_length.device,
+ )
+ self.accept_length = torch.cat([self.accept_length, zeros])
+ self.accept_length_cpu = self.accept_length.tolist()
+ elif spec_info.accept_length is not None:  # self.accept_length is None in this branch
+ zeros = torch.zeros(
+ [self.verified_id.shape[0] - spec_info.verified_id.shape[0]],  # self.verified_id was already merged above; pad only the original rows
+ dtype=spec_info.accept_length.dtype,
+ device=spec_info.accept_length.device,
+ )
+ self.accept_length = torch.cat([zeros, spec_info.accept_length])
+ self.accept_length_cpu = self.accept_length.tolist()
@dataclass
@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
- if self.server_args.disable_cuda_graph:
+ if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph:
return
Device2DraftCudaGraphRunner = {
@@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler):
"sglang.srt.model_executor.model_runner.",
"sglang.srt.layers.",
"sglang.srt.utils.",
+ # --- slime ---
+ "slime.",
}
DENY_CLASSES = {