From f45d5289e99793f169252879f92ee253273faac1 Mon Sep 17 00:00:00 2001
From: huyuanquan1 <huyuanquan1@huawei.com>
Date: Sat, 28 Feb 2026 15:09:05 +0800
Subject: [PATCH] feature chunk moe
vllm_ascend/ops/fused_moe/moe_comm_method.py | 80 ++++++++++++++++++++
1 file changed, 80 insertions(+)
@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
from __future__ import annotations
+import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional
@@ -36,6 +37,83 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
+def chunk_moe_decorator(fused_experts_func):
+    """Wrap ``fused_experts`` so that large batches run in token chunks.
+
+    The chunk size is read once, at decoration time, from the environment
+    variable ``VLLM_CHUNK_MOE_SIZE`` (default 512) and must be positive.
+    ``hidden_states``, ``topk_weights`` and ``topk_ids`` are taken from
+    keyword arguments only; everything else is forwarded unchanged.
+    """
+    from functools import wraps
+    chunk_moe_size = int(os.environ.get('VLLM_CHUNK_MOE_SIZE', 512))
+    if chunk_moe_size <= 0:
+        raise ValueError('VLLM_CHUNK_MOE_SIZE must be a positive integer')
+
+    @wraps(fused_experts_func)
+    def wrapper(*args, **kwargs):
+        hidden_states = kwargs.pop('hidden_states', None)
+        topk_weights = kwargs.pop('topk_weights', None)
+        topk_ids = kwargs.pop('topk_ids', None)
+
+        ctx = get_forward_context()
+        from vllm.distributed import get_tensor_model_parallel_world_size
+        tp_size = get_tensor_model_parallel_world_size()
+        # Ceiling division of ctx.max_tokens_across_dp by the TP world size.
+        max_tokens = (ctx.max_tokens_across_dp + tp_size - 1) // tp_size
+
+        if max_tokens < chunk_moe_size:
+            # Below the chunk size: call straight through, no chunking.
+            return fused_experts_func(
+                *args, hidden_states=hidden_states,
+                topk_weights=topk_weights, topk_ids=topk_ids, **kwargs)
+
+        num_tokens = hidden_states.size(0)
+        final_routed_out = torch.zeros_like(hidden_states)
+        all_expert_tokens = []
+        last_before_dispatch_evt = None
+        last_before_combine_evt = None
+        group_list_type = None
+        chunk_start_index = 0  # next row of final_routed_out to fill
+        for chunk_start in range(0, max_tokens, chunk_moe_size):
+            # max_tokens may exceed num_tokens: chunks starting past the end
+            # still call the wrapped function, but the result is discarded.
+            skip_result_store = chunk_start >= num_tokens
+            chunk_end = min(chunk_start + chunk_moe_size, num_tokens)
+            chunk_start = min(chunk_start, num_tokens - 1)
+            update_kwargs = dict(kwargs)
+            if update_kwargs.get('shared_experts'):
+                # NOTE(review): assumes shared_experts is sliceable - confirm
+                update_kwargs['shared_experts'] = update_kwargs[
+                    'shared_experts'][chunk_start:chunk_end]
+
+            res = fused_experts_func(
+                *args, hidden_states=hidden_states[chunk_start:chunk_end],
+                topk_weights=topk_weights[chunk_start:chunk_end],
+                topk_ids=topk_ids[chunk_start:chunk_end], **update_kwargs)
+            if skip_result_store:
+                continue
+            chunk_end_idx = chunk_start_index + res.routed_out.shape[0]
+            final_routed_out[chunk_start_index:chunk_end_idx, :] = res.routed_out
+            if res.expert_tokens is not None:
+                all_expert_tokens.append(res.expert_tokens)
+            last_before_dispatch_evt = res.before_dispatch_evt
+            last_before_combine_evt = res.before_combine_evt
+            group_list_type = res.group_list_type
+            chunk_start_index = chunk_end_idx
+
+        # Stays None when no chunk reported expert_tokens.
+        combined_expert_tokens = (
+            torch.cat(all_expert_tokens, dim=0) if all_expert_tokens else None)
+        return FusedExpertsResult(
+            routed_out=final_routed_out, expert_tokens=combined_expert_tokens,
+            before_dispatch_evt=last_before_dispatch_evt,
+            before_combine_evt=last_before_combine_evt,
+            group_list_type=group_list_type)
+
+    return wrapper
+
+
def get_moe_comm_method(
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
return _MoECommMethods.get(moe_comm_type, None)
@@ -100,6 +178,7 @@ class MoECommMethod(ABC):
context_metadata)
return hidden_states
+ @chunk_moe_decorator
def fused_experts(
self,
hidden_states: torch.Tensor,
@@ -275,6 +354,7 @@ class FusedMC2CommImpl(MoECommMethod):
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithMC2(self.moe_config)
+ @chunk_moe_decorator
def fused_experts(
self,
hidden_states: torch.Tensor,
--
2.45.1.windows.1