diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index aa10cb08d..d41c31a09 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -268,6 +268,12 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV32ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
             self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
 
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 51af67636..54716de5c 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -315,6 +315,13 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
+    def release_memory_occupation(self):
+        if hasattr(self.kv_manager, "close"):
+            self.kv_manager.close()
+
+    def resume_memory_occupation(self):
+        self.kv_manager = self._init_kv_manager()
+
     def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         if self._check_if_req_exceed_kv_capacity(req):
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index 32e8c0b69..df913da7b 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -1079,6 +1079,19 @@ class MooncakeKVManager(CommonKVManager):
             f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected"
         )
 
+    def close(self):
+        # Batch deregister KV data buffers
+        if self.kv_args.kv_data_ptrs:
+            self.engine.batch_deregister(self.kv_args.kv_data_ptrs)
+
+        # Batch deregister auxiliary data buffers
+        if self.kv_args.aux_data_ptrs:
+            self.engine.batch_deregister(self.kv_args.aux_data_ptrs)
+
+        # Batch deregister state/extra pool data buffers
+        if self.kv_args.state_data_ptrs:
+            self.engine.batch_deregister(self.kv_args.state_data_ptrs)
+
 
 class MooncakeKVSender(CommonKVSender):
 
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index a6eed743a..0124d8917 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -306,6 +306,13 @@ class PrefillBootstrapQueue:
         else:
             return bootstrapped_reqs, failed_reqs
 
+    def release_memory_occupation(self):
+        if hasattr(self.kv_manager, "close"):
+            self.kv_manager.close()
+
+    def resume_memory_occupation(self):
+        self.kv_manager = self._init_kv_manager()
+
 
 class SchedulerDisaggregationPrefillMixin:
     """
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 0478526ef..cfb1aa669 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1797,7 +1797,10 @@ def get_tensor_model_parallel_world_size():
 
 def get_tensor_model_parallel_rank():
     """Return my rank for the tensor model parallel group."""
-    return get_tp_group().rank_in_group
+    try:
+        return get_tp_group().rank_in_group
+    except Exception:
+        return 0
 
 
 def get_pipeline_model_parallel_world_size():
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 6f69fd19b..da20ac2ed 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -49,6 +49,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     MultimodalDataInputFormat,
+    PostProcessWeightsReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
@@ -593,6 +594,24 @@ class Engine(EngineBase):
             self.tokenizer_manager.update_weights_from_ipc(obj, None)
         )
 
+    def post_process_weights(
+        self,
+        restore_weights_before_load: bool = False,
+        post_process_quantization: bool = False,
+    ):
+        """
+        Optional post-processing for updated weights (e.g., Marlin conversion).
+        Should be called after weight update is finished.
+        """
+        obj = PostProcessWeightsReqInput(
+            restore_weights_before_load=restore_weights_before_load,
+            post_process_quantization=post_process_quantization,
+        )
+
+        return self.loop.run_until_complete(
+            self.tokenizer_manager.post_process_weights(obj, None)
+        )
+
     def get_weights_by_name(self, name: str, truncate_size: int = 100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 88705cc35..c8dc052f1 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -107,6 +107,7 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     ParseFunctionCallReq,
     PauseGenerationReqInput,
+    PostProcessWeightsReqInput,
     ProfileReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -957,6 +958,21 @@ async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Re
     else:
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
+@app.post("/post_process_weights")
+async def post_process_weights(req: PostProcessWeightsReqInput, request: Request):
+    """
+    Optional post-processing for updated weights (e.g., Marlin conversion).
+    This should be called selectively after `update_weights_from_distributed/update_weights_from_tensor`.
+    """
+    success, message = await _global_state.tokenizer_manager.post_process_weights(
+        req, request
+    )
+
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
 
 @app.post("/update_weight_version")
 async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index c9e82e4b1..f2584546a 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
+import os
 import torch
 from einops import rearrange
 
@@ -178,7 +179,7 @@ class Indexer(MultiPlatformOp):
             max_position=max_position_embeddings,
             base=rope_theta,  # type: ignore
             rope_scaling=rope_scaling,
-            is_neox_style=True,
+            is_neox_style=True if os.environ.get("INDEXER_ROPE_NEOX_STYLE", "1") == "1" else False,
             device=get_global_server_args().device,
         )
         self.block_size = block_size
@@ -188,6 +189,9 @@ class Indexer(MultiPlatformOp):
     @torch.compile(dynamic=True)
     def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
         weights, _ = self.weights_proj(x.float())
+        if weights.shape[1] < 32:
+            assert 32 % weights.shape[1] == 0
+            weights = weights.repeat_interleave(32 // weights.shape[1], dim=1)
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
@@ -837,6 +841,9 @@ class Indexer(MultiPlatformOp):
         query, key = self._get_q_k_bf16(
             q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
         )
+        if query.shape[1] < 32:
+            assert 32 % query.shape[1] == 0
+            query = query.repeat_interleave(32//query.shape[1], dim=1)
 
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 7bef9d2ab..f588cbdb0 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -83,15 +83,12 @@ class RMSNorm(MultiPlatformOp):
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
         cast_x_before_out_mul: bool = False,
-        fp32_residual: bool = False,
-        weight_dtype: Optional = None,
-        override_orig_dtype: Optional = None,
+        fp32_residual: bool = True,
     ) -> None:
         super().__init__()
         self.cast_x_before_out_mul = cast_x_before_out_mul
         self.fp32_residual = fp32_residual
-        self.override_orig_dtype = override_orig_dtype
-        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
+        self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -193,10 +190,22 @@ class RMSNorm(MultiPlatformOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = self.override_orig_dtype or x.dtype
+        orig_dtype = x.dtype
         post_residual_addition = kwargs.get("post_residual_addition")
+
+        if residual is not None and not self.fp32_residual:
+            x = (
+                x
+                + residual
+                + (
+                    post_residual_addition
+                    if post_residual_addition is not None
+                    else 0.0
+                )
+            )
+            residual = x.clone()
         x = x.to(torch.float32)
-        if residual is not None:
+        if residual is not None and self.fp32_residual:
             x = (
                 x
                 + residual.to(torch.float32)
@@ -206,10 +215,7 @@ class RMSNorm(MultiPlatformOp):
                     else 0.0
                 )
             )
-            if self.fp32_residual:
-                residual = x.clone()
-            else:
-                residual = x.to(orig_dtype)
+            residual = x.to(orig_dtype)
 
         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index fa7431048..cd33ea735 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -878,11 +878,6 @@ class LogitsProcessor(nn.Module):
                     None,  # bias
                     True,  # is_vnni
                 )
-            elif get_global_server_args().rl_on_policy_target is not None:
-                # Due to tie-weight, we may not be able to change lm_head's weight dtype
-                logits = torch.matmul(
-                    hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
-                )
             else:
                 logits = torch.matmul(
                     hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index a1885fade..14d692365 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
 import triton.language as tl
 
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -573,7 +574,10 @@ def fused_experts_impl(
                 ).squeeze(dim=1)
             else:
                 # According to micro benchmark results, torch.compile can get better performance for small token.
-                if tokens_in_chunk <= 32:
+                if (
+                    not get_global_server_args().enable_deterministic_inference
+                    and tokens_in_chunk <= 32
+                ):
                     moe_sum_reduce_torch_compile(
                         intermediate_cache3.view(*intermediate_cache3.shape),
                         out_hidden_states[begin_chunk_idx:end_chunk_idx],
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index 00bd68755..5a3ca8a67 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -1,5 +1,6 @@
 import logging
 from abc import ABC
+from contextlib import contextmanager
 from typing import Optional
 
 import numpy as np
@@ -8,13 +9,18 @@ import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather_into_tensor,
     get_attention_dp_rank,
+    get_attention_tp_size,
     get_dp_local_info,
     is_dp_attention_enabled,
 )
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import get_global_server_args
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -181,13 +187,26 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
             device=device,
         )
 
+        if get_moe_a2a_backend().is_deepep():
+            attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1
+            self.gather_buffer = torch.empty(
+                (
+                    self.device_cache.buffer.shape[0] * attn_tp_size,
+                    self.device_cache.buffer.shape[2],
+                ),
+                dtype=torch.int32,
+                device=device,
+            )
+
     def _sync_fwd_experts_buffer_DtoH(
         self,
         forward_batch: ForwardBatch,
         can_run_graph: bool,
         cuda_graph_batch: int,
     ):
-        if is_dp_attention_enabled():
+        # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer
+        # contains data from all DP ranks. We should not slice by DP rank in this case.
+        if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
             local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
             # handle with cuda graph padding
             if can_run_graph:
@@ -206,6 +225,12 @@ class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
         ].cpu()
 
     def capture(self, layer_id: int, topk_ids: torch.Tensor):
+        if get_moe_a2a_backend().is_deepep():
+            local_topk_ids = topk_ids
+            topk_ids = self.gather_buffer[
+                : local_topk_ids.size(0) * get_attention_tp_size()
+            ]
+            attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids)
         self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
 
     def get_routed_experts(
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c5e5a11fc..dd321fa13 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1016,13 +1016,37 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         layer.a2_scale = None
         layer.marlin_state = GPTQMarlinState.REPACK
 
+        if not hasattr(layer, "_original_shapes"):
+            layer._original_shapes = {}
+
+        # Force record: these are the target GPTQ shapes for rollback.
+        layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
+        layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)
+
+        # Also record the shapes of the scales.
+        layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
+        layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Skip if the layer is already converted to Marlin format to prevent double-packing.
+        if getattr(layer, "is_marlin_converted", False):
+            return
+
+        if not hasattr(layer, "_original_shapes"):
+            layer._original_shapes = {}
 
         def replace_tensor(name, new_t):
+            target_attr = getattr(layer, name)
+
+            # Only save if the key doesn't exist to prevent overwriting with Marlin shapes.
+            if name not in layer._original_shapes:
+                # This is a safety check; `create_weights` usually handles this already.
+                layer._original_shapes[name] = tuple(target_attr.shape)
+
             # It is important to use resize_() here since it ensures
             # the same buffer is reused
-            getattr(layer, name).resize_(new_t.shape)
-            getattr(layer, name).copy_(new_t)
+            target_attr.resize_(new_t.shape)
+            target_attr.copy_(new_t)
             del new_t
 
         num_experts = layer.w13_weight_g_idx.shape[0]
@@ -1078,7 +1102,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+        replace_tensor("w13_weight_packed", marlin_w13_qweight)
         marlin_w2_qweight = gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
@@ -1086,7 +1110,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
+        replace_tensor("w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             layer.w13_weight_scale,
@@ -1094,7 +1118,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             layer.w13_weight_scale.shape[2],
             self.group_size,
         )
-        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
+        replace_tensor("w13_weight_scale", marlin_w13_scales)
 
         marlin_w2_scales = marlin_moe_permute_scales(
             layer.w2_weight_scale,
@@ -1103,7 +1127,22 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight_scale.shape[2],
             self.group_size,
         )
-        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
+        replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+        layer.is_marlin_converted = True
+
+    def restore_weights_before_loading(self, layer: torch.nn.Module):
+        """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights."""
+        if not hasattr(layer, "_original_shapes"):
+            return
+
+        for name, orig_shape in layer._original_shapes.items():
+            param = getattr(layer, name, None)
+
+            if param is not None and param.shape != orig_shape:
+                param.resize_(orig_shape)
+
+        layer.is_marlin_converted = False
 
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 480579e01..dd8ca7d4f 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -136,9 +136,7 @@ class RotaryEmbedding(MultiPlatformOp):
 
         if get_global_server_args().rl_on_policy_target is not None:
             self._forward_method = self.forward_native
-            self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
-                self._apply_rotary_emb_wrapped
-            )
+
         self.position_cos, self.position_sin = None, None
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -1578,6 +1576,9 @@ class MRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
         # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096
         assert (
             fused_set_kv_buffer_arg is None
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 55bef5652..35ad68b1c 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -108,16 +108,11 @@ class Sampler(nn.Module):
             if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
-            if get_global_server_args().rl_on_policy_target is not None:
-                logits_div_temperature = (
-                    logits.bfloat16().div(sampling_info.temperatures).bfloat16()
-                )
-                logprobs_via_logsoftmax_kernel = torch.log_softmax(
-                    logits_div_temperature, dim=-1
-                )
-
             # Post process logits
             logits.div_(sampling_info.temperatures)
+            if get_global_server_args().rl_on_policy_target is not None:
+                logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1)
+
             # For ascend backend, softmax is not needed before sampling
             if not get_global_server_args().sampling_backend == "ascend" or (
                 return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 2ecd8542f..89ef8200d 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1292,6 +1292,19 @@ class UpdateWeightsFromIPCReqOutput(BaseReq):
     success: bool
     message: str
 
+@dataclass
+class PostProcessWeightsReqInput(BaseReq):
+    # Whether to restore weights before loading new weights
+    restore_weights_before_load: bool = False
+    # Whether to enable quantization post-processing
+    post_process_quantization: bool = False
+
+
+@dataclass
+class PostProcessWeightsReqOutput(BaseReq):
+    success: bool
+    message: str
+
 
 @dataclass
 class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d423e61d7..9156d543c 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -2186,7 +2186,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
-            f"#req={(len(self.reqs))})"
+            f"#req={(len(self.reqs))}), "
+            f"#out_cache_loc={self.out_cache_loc})"
         )
 
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 92d286897..43bfab691 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -98,6 +98,7 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     PauseGenerationReqInput,
+    PostProcessWeightsReqInput,
     ProfileReq,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -1060,6 +1061,7 @@ class Scheduler(
                 ),
                 (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                 (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
+                (PostProcessWeightsReqInput, self.post_process_weights),
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index e40586c24..243e2b0c2 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -10,6 +10,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
+
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOutput,
@@ -1070,7 +1071,7 @@ class SchedulerOutputProcessorMixin:
                 req.log_time_stats()
 
         # Send to detokenizer
-        if reqs or is_idle_batch:
+        if rids or is_idle_batch:
             if self.model_config.is_multimodal_gen:
                 return
 
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index 293a84350..c3a618bcc 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 import traceback
 from typing import TYPE_CHECKING, Tuple
 
@@ -12,6 +13,9 @@ from sglang.srt.constants import (
     GPU_MEMORY_TYPE_KV_CACHE,
     GPU_MEMORY_TYPE_WEIGHTS,
 )
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.managers.io_struct import (
     CheckWeightsReqInput,
     CheckWeightsReqOutput,
@@ -21,6 +25,8 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    PostProcessWeightsReqInput,
+    PostProcessWeightsReqOutput,
     ReleaseMemoryOccupationReqInput,
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
@@ -114,6 +120,11 @@ class SchedulerUpdateWeightsMixin:
         torch.distributed.barrier(group=self.tp_cpu_group)
         return UpdateWeightsFromIPCReqOutput(success, message)
 
+    def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+        """Optional post-processing for updated weights (e.g., Marlin conversion)."""
+        success, message = self.tp_worker.post_process_weights(recv_req)
+        return PostProcessWeightsReqOutput(success, message)
+
     def get_weights_by_name(self: Scheduler, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
@@ -137,6 +148,13 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
             self.flush_cache()
 
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                if hasattr(self, "disagg_decode_prealloc_queue"):
+                    self.disagg_decode_prealloc_queue.release_memory_occupation()
+            elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+                if hasattr(self, "disagg_prefill_bootstrap_queue"):
+                    self.disagg_prefill_bootstrap_queue.release_memory_occupation()
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
                 self.tp_worker.model_runner.model
@@ -177,6 +195,13 @@ class SchedulerUpdateWeightsMixin:
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
 
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                if hasattr(self, "disagg_decode_prealloc_queue"):
+                    self.disagg_decode_prealloc_queue.resume_memory_occupation()
+            elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+                if hasattr(self, "disagg_prefill_bootstrap_queue"):
+                    self.disagg_prefill_bootstrap_queue.resume_memory_occupation()
+
         return ResumeMemoryOccupationReqOutput()
 
     def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index e5d42bed8..412293b30 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -49,6 +49,8 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqOutput,
     LoRAUpdateOutput,
     OpenSessionReqInput,
+    PostProcessWeightsReqInput,
+    PostProcessWeightsReqOutput,
     ProfileReq,
     ProfileReqOutput,
     ProfileReqType,
@@ -177,6 +179,9 @@ class TokenizerCommunicatorMixin:
         self.update_weights_from_ipc_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.post_process_weights_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -250,6 +255,10 @@ class TokenizerCommunicatorMixin:
                     UpdateWeightsFromIPCReqOutput,
                     self.update_weights_from_ipc_communicator.handle_recv,
                 ),
+                (
+                    PostProcessWeightsReqOutput,
+                    self.post_process_weights_communicator.handle_recv,
+                ),
                 (
                     GetWeightsByNameReqOutput,
                     self.get_weights_by_name_communicator.handle_recv,
@@ -433,6 +442,17 @@ class TokenizerCommunicatorMixin:
 
         return success, message
 
+    async def post_process_weights(
+        self: TokenizerManager,
+        obj: PostProcessWeightsReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        """Trigger post-processing hooks for weights after loading (e.g., Marlin conversion)."""
+        self.auto_create_handle_loop()
+        async with self.model_update_lock.writer_lock:
+            results = await self.post_process_weights_communicator(obj)
+            return _Communicator.merge_results(results)
+
     async def init_weights_send_group_for_remote_instance(
         self,
         obj: InitWeightsSendGroupForRemoteInstanceReqInput,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 49f63a198..e4cd0ff2b 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    PostProcessWeightsReqInput,
     SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
@@ -175,6 +176,11 @@ class BaseTpWorker(ABC):
         success, message = self.model_runner.update_weights_from_ipc(recv_req)
         return success, message
 
+    def post_process_weights(self, recv_req: PostProcessWeightsReqInput):
+        """Perform optional post-processing on the updated model weights (e.g., Marlin conversion)."""
+        success, message = self.model_runner.post_process_weights(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 65d562a27..b00a20e95 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -1678,7 +1678,8 @@ class NSATokenToKVPool(MLATokenToKVPool):
         with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.custom_mem_pool
-            else nullcontext()
+            else nullcontext(),
+            self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE),
         ):
             self.index_k_with_scale_buffer = [
                 torch.zeros(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 1d69c0582..d984c2e12 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -558,7 +558,8 @@ class ModelRunner(ModelRunnerKVCacheMixin):
         )
 
         # Init routed experts capturer
-        self.init_routed_experts_capturer()
+        if not self.is_draft_worker:
+            self.init_routed_experts_capturer()
 
         if self.device == "cuda":
             self.init_cublas()
@@ -2224,11 +2225,19 @@ class ModelRunner(ModelRunnerKVCacheMixin):
         output.expert_distribution_metrics = recorder_outputs.get("metrics")
 
         # Copy cached routing experts' buffers back to CPU cache
-        get_global_experts_capturer().on_forward_end(
-            forward_batch=forward_batch,
-            can_run_graph=output.can_run_graph,
-            cuda_graph_batch=getattr(self.graph_runner, "bs", None),
-        )
+        if not self.is_draft_worker:
+            # In speculative decoding, num_tokens_per_bs > 1, so we need to pass
+            # the actual number of tokens per dp rank in cuda graph, not batch size.
+            cuda_graph_num_tokens = None
+            if getattr(self.graph_runner, "bs", None):
+                cuda_graph_num_tokens = (
+                    self.graph_runner.bs * self.graph_runner.num_tokens_per_bs
+                )
+            get_global_experts_capturer().on_forward_end(
+                forward_batch=forward_batch,
+                can_run_graph=output.can_run_graph,
+                cuda_graph_batch=cuda_graph_num_tokens,
+            )
 
         if self.eplb_manager is not None:
             self.eplb_manager.on_forward_pass_end()
@@ -2436,6 +2445,41 @@ class ModelRunner(ModelRunnerKVCacheMixin):
             logger.error(f"IPC weight update failed: {e}")
             return False, str(e)
 
+    def post_process_weights(self, recv_req):
+        """
+        Execute post-processing logic for model weights, such as Marlin quantization format conversion.
+        """
+        from sglang.srt.model_loader.loader import device_loading_context
+
+        target_device = torch.device("cuda", torch.cuda.current_device())
+
+        if recv_req.restore_weights_before_load:
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+
+                # Check if the module supports restoring weights
+                if quant_method is not None and hasattr(
+                    quant_method, "restore_weights_before_loading"
+                ):
+
+                    with device_loading_context(module, target_device):
+                        quant_method.restore_weights_before_loading(module)
+
+        if recv_req.post_process_quantization:
+            # Iterate through all modules to apply specific post-loading processing
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+
+                # Check if the module supports quantization post-processing
+                if quant_method is not None and hasattr(
+                    quant_method, "process_weights_after_loading"
+                ):
+
+                    # Apply the post-processing (e.g., repacking weights for Marlin kernel)
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+        return True, "Success"
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index ed8cc7ada..d44c8aaa0 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2704,7 +2704,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         ):
             k = k_nope.new_empty(*k_shape)
             concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        elif _is_cuda:
+        elif _is_cuda and all(
+            # (i.bit_count() == 1) == (is_power_of_two(i))
+            i.bit_count() == 1
+            for i in (k_shape[1], k_nope.shape[-1], k_pe.shape[-1])
+        ):
             # fa3 mha support fp8 inputs
             if (
                 self.current_attention_backend == "fa3"
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index a7dbadec6..c83a41338 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if get_global_server_args().rl_on_policy_target is not None:
-            x = x.bfloat16()
-
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
-                params_dtype=(
-                    torch.float32
-                    if get_global_server_args().rl_on_policy_target is not None
-                    else None
-                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module):
         if self.pp_group.is_last_rank:
             norm_kwargs = (
                 dict(
-                    weight_dtype=torch.float32,
                     cast_x_before_out_mul=True,
-                    override_orig_dtype=torch.float32,
-                    fp32_residual=True,
+                    fp32_residual=False,
                 )
                 if get_global_server_args().rl_on_policy_target is not None
                 else {}
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 3ad9f6736..0b9c7f499 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -586,7 +586,17 @@ class Qwen2MoeModel(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    cast_x_before_out_mul=True,
+                    fp32_residual=False,
+                )
+                if get_global_server_args().rl_on_policy_target is not None
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 9220831f6..47a1a4e4c 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module):
 
         norm_kwargs = (
             dict(
-                weight_dtype=torch.float32,
                 cast_x_before_out_mul=True,
+                fp32_residual=False,
             )
             if get_global_server_args().rl_on_policy_target is not None
             else {}
@@ -242,10 +242,8 @@ class Qwen3DecoderLayer(nn.Module):
 
         norm_kwargs = (
             dict(
-                weight_dtype=torch.float32,
                 cast_x_before_out_mul=True,
-                override_orig_dtype=torch.float32,
-                fp32_residual=True,
+                fp32_residual=False,
             )
             if get_global_server_args().rl_on_policy_target is not None
             else {}
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index e11678a9e..e277d46f2 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -22,6 +22,7 @@ import math
 from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 
@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
 from sglang.srt.layers.moe.utils import RoutingMethodType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -229,6 +230,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
             layer_id=layer_id,
         )
+        self.top_k = config.num_experts_per_tok
 
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
@@ -294,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
+
+        if get_global_server_args().rl_on_policy_target is not None:
+            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+            routing_weights, selected_experts = torch.topk(
+                routing_weights, self.top_k, dim=-1
+            )
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            routing_weights = routing_weights.to(hidden_states.dtype)
+            topk_output = StandardTopKOutput(
+                topk_weights=routing_weights,
+                topk_ids=selected_experts,
+                router_logits=router_logits,
+            )
+        else:
+            topk_output = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(hidden_states, topk_output)
         if (
             self.tp_size > 1
@@ -475,13 +492,14 @@ class Qwen3MoeAttention(nn.Module):
         )
         self.compatible_with_fused_kv_buffer = (
             False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
-        )
+        ) and (get_global_server_args().rl_on_policy_target is None)
         self.compatible_with_fused_qk_norm_rope = (
             not isinstance(self.rotary_emb, MRotaryEmbedding)
         ) and self.head_dim in (64, 128, 256)
         self.use_fused_qk_norm_rope = (
             get_global_server_args().enable_fused_qk_norm_rope
             and self.compatible_with_fused_qk_norm_rope
+            and (get_global_server_args().rl_on_policy_target is None)
         )
         self._used_fused_qk_norm_rope_last_call = False
 
@@ -494,8 +512,16 @@ class Qwen3MoeAttention(nn.Module):
             prefix=add_prefix("attn", prefix),
         )
 
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                cast_x_before_out_mul=True,
+                fp32_residual=False,
+            )
+            if get_global_server_args().rl_on_policy_target is not None
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
         self.alt_stream = alt_stream
 
     def op_prepare(self, state):
@@ -736,9 +762,19 @@ class Qwen3MoeDecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                cast_x_before_out_mul=True,
+                fp32_residual=False,
+            )
+            if get_global_server_args().rl_on_policy_target is not None
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )
 
         self.layer_communicator = LayerCommunicator(
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index 079f45843..218e32362 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -397,28 +397,68 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin):
         return cos_combined, sin_combined
 
     def fast_pos_embed_interpolate(self, grid_thw):
-        patch_pos_embeds_permute = []
-        m_size = self.spatial_merge_size
+        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+        num_grid_per_side = int(self.num_position_embeddings**0.5)
+        device = self.pos_embed.weight.device
+
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+            h_idxs = torch.linspace(0, num_grid_per_side - 1, h)
+            w_idxs = torch.linspace(0, num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.int()
+            w_idxs_floor = w_idxs.int()
+            h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            base_h = h_idxs_floor * num_grid_per_side
+            base_h_ceil = h_idxs_ceil * num_grid_per_side
+
+            indices = [
+                (base_h[None].T + w_idxs_floor[None]).flatten(),
+                (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                ((1 - dh)[None].T * dw[None]).flatten(),
+                (dh[None].T * (1 - dw)[None]).flatten(),
+                (dh[None].T * dw[None]).flatten(),
+            ]
 
-        embeds = torch.arange(self.num_grid, device=self.pos_embed.weight.device)
-        embeds = (
-            self.pos_embed(embeds)
-            .permute(1, 0)
-            .reshape(1, -1, self.num_grid_per_side, self.num_grid_per_side)
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
+        weight_tensor = torch.tensor(
+            weight_list, dtype=self.pos_embed.weight.dtype, device=device
         )
-        for t, h, w in grid_thw:
-            pos_embed = torch.nn.functional.interpolate(
-                embeds, size=(h, w), mode="bilinear", align_corners=self.align_corners
-            )
-            pos_embed = pos_embed.reshape(
-                -1,
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
+        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split(
+            [h * w for h, w in zip(grid_hs, grid_ws)]
+        )
+
+        patch_pos_embeds_permute = []
+        merge_size = self.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(
+                    t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+                )
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
             )
-            pos_embed = pos_embed.permute(1, 3, 2, 4, 0)
-            pos_embed = pos_embed.flatten(0, 3).repeat(t, 1)
             patch_pos_embeds_permute.append(pos_embed)
         return torch.cat(patch_pos_embeds_permute)
 
@@ -610,14 +650,19 @@ class Qwen3LLMModel(Qwen3Model):
                     hidden_states + residual if residual is not None else hidden_states
                 )
 
+            deepstack_embeds = None
+            if input_deepstack_embeds is not None:
+                prev_layer_idx = layer_idx - 1
+                if prev_layer_idx in self.deepstack_embed_to_decoder_layer:
+                    sep = self.hidden_size * prev_layer_idx
+                    deepstack_embeds = input_deepstack_embeds[
+                        :, sep : sep + self.hidden_size
+                    ]
+
             # SGLang applies residual at the START of the next layer, not at the END like HuggingFace.
             # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549
             # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack
             # The order matters because addition with different tensors is not associative in practice.
-            # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
-            deepstack_embeds = self.get_deepstack_embeds(
-                layer_idx - 1, input_deepstack_embeds
-            )
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a2b26e0e0..72db29801 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -527,6 +527,7 @@ class ServerArgs:
     cuda_graph_max_bs: Optional[int] = None
     cuda_graph_bs: Optional[List[int]] = None
     disable_cuda_graph: bool = False
+    disable_draft_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     enable_profile_cuda_graph: bool = False
     enable_cudagraph_gc: bool = False
@@ -3980,6 +3981,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-draft-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph for draft model in speculative decoding.",
+        )
         parser.add_argument(
             "--disable-cuda-graph-padding",
             action="store_true",
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 5fe45086c..c95fbd0f6 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -341,7 +341,10 @@ class EAGLEDraftCudaGraphRunner:
             self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
             self.positions.zero_()
-
+            self.topk_p.zero_()
+            self.topk_index.zero_()
+            self.hidden_states.zero_()
+            self.req_pool_indices.zero_()
         num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
@@ -350,8 +353,8 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.out_cache_loc
         )
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
-        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p.clamp(0, 1))
+        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index.clamp(0, self.model_runner.model_config.vocab_size - 1))
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 1bf3816e9..b5b41dba4 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -778,6 +778,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
             self.topk_index = self.topk_index[: len(new_indices)]
             self.hidden_states = self.hidden_states[: len(new_indices)]
             self.verified_id = self.verified_id[: len(new_indices)]
+            if self.accept_length is not None:
+                self.accept_length = self.accept_length[: len(new_indices)]
+            if self.accept_length_cpu is not None:
+                self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)]
         else:
             # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
             self.topk_p = self.topk_p[new_indices]
@@ -809,6 +813,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
         self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+        if self.accept_length is not None and spec_info.accept_length is not None:
+            self.accept_length = torch.cat(
+                [self.accept_length, spec_info.accept_length]
+            )
+            self.accept_length_cpu = self.accept_length.tolist()
+        elif self.accept_length is not None:
+            zeros = torch.zeros(
+                [spec_info.verified_id.shape[0]],
+                dtype=self.accept_length.dtype,
+                device=self.accept_length.device,
+            )
+            self.accept_length = torch.cat([self.accept_length, zeros])
+            self.accept_length_cpu = self.accept_length.tolist()
+        elif spec_info.accept_length is not None:
+            zeros = torch.zeros(
+                [self.verified_id.shape[0]],
+                dtype=self.accept_length.dtype,
+                device=self.accept_length.device,
+            )
+            self.accept_length = torch.cat([zeros, spec_info.accept_length])
+            self.accept_length_cpu = self.accept_length.tolist()
 
 
 @dataclass
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index a702df4f8..61d9ae366 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -231,7 +231,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None
 
-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or self.server_args.disable_draft_cuda_graph:
             return
 
         Device2DraftCudaGraphRunner = {
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 8560246c6..13db860dc 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -2224,6 +2224,8 @@ class SafeUnpickler(pickle.Unpickler):
         "sglang.srt.model_executor.model_runner.",
         "sglang.srt.layers.",
         "sglang.srt.utils.",
+        # --- slime ---
+        "slime.",
     }
 
     DENY_CLASSES = {