import threading
import torch
import torch_npu
from torch import nn
class ExpertLoadCollector(nn.Module):
    """Accumulate per-expert routing statistics and hand them off to the host.

    ``collect_expert_load`` passes routed expert indices to
    ``torch_npu.npu_moe_compute_expert_tokens`` and adds the result into the
    ``expert_group_list`` buffer. ``get_expert_load`` snapshots that
    accumulator, zeroes it, and returns the snapshot as a CPU tensor.

    Args:
        routed_expert_num: Number of routed experts; sets the length of every
            internal 1-D ``torch.long`` tensor.
        lb_interval: Stored on the instance; not read by the methods of this
            class (TODO: confirm where it is consumed).
    """

    def __init__(self, routed_expert_num: int, lb_interval: int = 1) -> None:
        super().__init__()
        self.routed_expert_num = routed_expert_num
        # Snapshot of the accumulator, filled by get_expert_load before reset.
        self.register_buffer(
            'expert_data_buffer',
            torch.zeros(self.routed_expert_num, dtype=torch.long),
        )
        # Running accumulator, incremented by collect_expert_load.
        self.register_buffer(
            'expert_group_list',
            torch.zeros(self.routed_expert_num, dtype=torch.long),
        )
        # Pinned host tensor reused as the destination of every get_expert_load.
        self.experts_load_cpu = torch.zeros(
            self.routed_expert_num, dtype=torch.long
        ).pin_memory()
        # Orders the accumulate in collect_expert_load against the
        # snapshot + reset in get_expert_load.
        self.buffer_lock = threading.Lock()
        self.lb_interval = lb_interval
        # Optional hook; when set, its profile_emit_task() runs after each collect.
        self.task_transfer = None

    def get_expert_load(self) -> torch.Tensor:
        """Return the load accumulated since the previous call and clear it.

        The accumulator is copied into ``expert_data_buffer`` and zeroed while
        holding ``buffer_lock``. The copy into the host tensor happens after
        the lock is released.

        Returns:
            ``self.experts_load_cpu``: the same pinned CPU tensor on every
            call, overwritten in place. Clone it if the values must survive
            the next call.
        """
        with self.buffer_lock:
            self.expert_data_buffer.copy_(self.expert_group_list)
            self.reset()
        # NOTE(review): expert_data_buffer and experts_load_cpu are touched
        # outside the lock, so two concurrent get_expert_load callers could
        # clobber each other's snapshot. Presumably there is a single reader;
        # confirm before calling this from more than one thread.
        self.experts_load_cpu.copy_(self.expert_data_buffer)
        return self.experts_load_cpu

    def collect_expert_load(self, indices_expert: torch.Tensor) -> None:
        """Add one batch of routed expert indices into the accumulator.

        Args:
            indices_expert: Forwarded unchanged to
                ``torch_npu.npu_moe_compute_expert_tokens`` together with
                ``routed_expert_num``. Presumably the expert ids produced by a
                MoE routing step; confirm the required layout (e.g. sorted)
                against the torch_npu documentation.
        """
        # The per-batch result is computed outside the lock; only the in-place
        # accumulate into the shared buffer is serialized.
        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            indices_expert, self.routed_expert_num
        )
        with self.buffer_lock:
            self.expert_group_list.add_(expert_tokens)
        # task_transfer defaults to None; compare against None explicitly so a
        # hook object that defines __bool__/__len__ is never skipped by accident.
        if self.task_transfer is not None:
            self.task_transfer.profile_emit_task()

    def reset(self) -> None:
        """Zero the accumulator in place.

        Does not acquire ``buffer_lock`` itself; ``get_expert_load`` calls it
        while already holding the lock.
        """
        self.expert_group_list.zero_()