From f45d5289e99793f169252879f92ee253273faac1 Mon Sep 17 00:00:00 2001

From: huyuanquan1 <huyuanquan1@huawei.com>

Date: Sat, 28 Feb 2026 15:09:05 +0800

Subject: [PATCH] feature chunk moe



---

 vllm_ascend/ops/fused_moe/moe_comm_method.py | 80 ++++++++++++++++++++

 1 file changed, 80 insertions(+)



diff --git a/llm_rl/qwen3/vllm_ascend/ops/fused_moe/moe_comm_method.py b/llm_rl/qwen3/vllm_ascend/ops/fused_moe/moe_comm_method.py

index 458557e9..ebb0790e 100644

--- a/llm_rl/qwen3/vllm_ascend/ops/fused_moe/moe_comm_method.py

+++ b/llm_rl/qwen3/vllm_ascend/ops/fused_moe/moe_comm_method.py

@@ -15,6 +15,7 @@

 # This file is a part of the vllm-ascend project.

 from __future__ import annotations

 

+import os

 from abc import ABC, abstractmethod

 from dataclasses import dataclass

 from typing import Dict, Optional

@@ -36,6 +37,83 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (

 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}

 

 

+def chunk_moe_decorator(fused_experts_func):

+    """Wrap ``fused_experts`` so that large batches run in token chunks.

+

+    The chunk size is read once, at decoration time, from the environment

+    variable ``VLLM_CHUNK_MOE_SIZE`` (default 512). The call is forwarded

+    unchanged when that size is not positive, when the token bound

+    ceil(ctx.max_tokens_across_dp / tp_size) is below it, or when any of

+    hidden_states / topk_weights / topk_ids is missing from kwargs or None.

+    Otherwise the wrapped function runs once per chunk and one

+    FusedExpertsResult is assembled from the per-chunk results.

+    """

+    import functools

+

+    chunk_moe_size = int(os.environ.get('VLLM_CHUNK_MOE_SIZE', 512))

+    token_keys = ('hidden_states', 'topk_weights', 'topk_ids')

+

+    @functools.wraps(fused_experts_func)

+    def wrapper(*args, **kwargs):

+        from vllm.distributed import get_tensor_model_parallel_world_size

+        ctx = get_forward_context()

+        tp_size = get_tensor_model_parallel_world_size()

+        # Ceil-divide ctx.max_tokens_across_dp by the TP world size.

+        max_tokens = (ctx.max_tokens_across_dp + tp_size - 1) // tp_size

+        if (chunk_moe_size <= 0 or max_tokens < chunk_moe_size

+                or any(kwargs.get(key) is None for key in token_keys)):

+            return fused_experts_func(*args, **kwargs)

+

+        hidden_states = kwargs['hidden_states']

+        num_tokens = hidden_states.size(0)

+        final_routed_out = torch.zeros_like(hidden_states)

+        all_expert_tokens = []

+        dispatch_evt = combine_evt = group_list_type = None

+        write_pos = 0  # next unwritten row of final_routed_out

+        # NOTE(review): rows at or beyond max_tokens are never visited and

+        # stay zero; confirm hidden_states never has more rows than that.

+        for chunk_start in range(0, max_tokens, chunk_moe_size):

+            # Once this rank's rows are used up the call is still issued, on

+            # the last row only, and its result is dropped -- presumably to

+            # keep all ranks issuing the same number of calls; confirm.

+            skip_result_store = chunk_start >= num_tokens

+            chunk_end = min(chunk_start + chunk_moe_size, num_tokens)

+            chunk_start = min(chunk_start, num_tokens - 1)

+            chunk_kwargs = dict(kwargs)

+            for key in token_keys:

+                chunk_kwargs[key] = kwargs[key][chunk_start:chunk_end]

+            # NOTE(review): assumes shared_experts, when given, is sliceable.

+            shared = chunk_kwargs.get('shared_experts')

+            if shared is not None:

+                chunk_kwargs['shared_experts'] = shared[chunk_start:chunk_end]

+

+            res = fused_experts_func(*args, **chunk_kwargs)

+            if skip_result_store:

+                continue

+            next_pos = write_pos + res.routed_out.shape[0]

+            final_routed_out[write_pos:next_pos, :] = res.routed_out

+            write_pos = next_pos

+            if res.expert_tokens is not None:

+                all_expert_tokens.append(res.expert_tokens)

+            # Events and group-list type come from the last stored chunk.

+            dispatch_evt = res.before_dispatch_evt

+            combine_evt = res.before_combine_evt

+            group_list_type = res.group_list_type

+

+        # Bound on every path; stays None if no chunk had expert tokens.

+        combined_expert_tokens = None

+        if all_expert_tokens:

+            combined_expert_tokens = torch.cat(all_expert_tokens, dim=0)

+        return FusedExpertsResult(

+            routed_out=final_routed_out,

+            before_dispatch_evt=dispatch_evt,

+            before_combine_evt=combine_evt,

+            group_list_type=group_list_type,

+            expert_tokens=combined_expert_tokens)

+

+    return wrapper

+

+

 def get_moe_comm_method(

         moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:

     return _MoECommMethods.get(moe_comm_type, None)

@@ -100,6 +178,7 @@ class MoECommMethod(ABC):

                                                        context_metadata)

         return hidden_states

 

+    @chunk_moe_decorator

     def fused_experts(

             self,

             hidden_states: torch.Tensor,

@@ -275,6 +354,7 @@ class FusedMC2CommImpl(MoECommMethod):

     def _get_prepare_finalize(self):

         return PrepareAndFinalizeWithMC2(self.moe_config)

 

+    @chunk_moe_decorator

     def fused_experts(

             self,

             hidden_states: torch.Tensor,

-- 

2.45.1.windows.1