import random
from dataclasses import dataclass
import numpy as np
import torch
from mindiesd.utils.logs.logging import logger
from mindiesd.utils.exception import ParametersInvalid
@dataclass
class LoadData:
    """
    Bundle of the mutable state handed to the placement / optimization steps.

    The first four fields are required; the remaining ones default to None
    and are filled in only by the steps that need them.
    """

    # --- required -----------------------------------------------------
    # placement[device, expert] == 1 marks the expert as deployed on the device.
    placement: np.ndarray
    # Id of the shared expert, or None when there is no shared expert.
    shared_expert_id: int
    # total_traffic[device, expert]: traffic counted on a device for an expert.
    total_traffic: np.ndarray
    # Memory already consumed on each device, keyed by device id.
    used_mems: dict

    # --- optional (None until provided) --------------------------------
    # Per-expert demand aggregated over devices.
    global_expert_load: np.ndarray = None
    # Current device -> list of expert ids mapping.
    device_to_expert: dict = None
    # Expert ids ordered by descending demand.
    sorted_experts: np.ndarray = None
    # Device -> expert list mapping supplied by the caller before rebalancing.
    origin_device_to_expert: dict = None
    # Square matrix whose columns are swapped to record expert exchanges.
    expert_trans_tensor: torch.Tensor = None
class EPLBService:
    """
    Base class of the EPLB placement algorithms.

    Attributes:
        num_devices (int): total number of devices (NPU cards) (K)
        num_experts (int): number of experts in the MoE layer (N)
        expert_mems (dict): memory footprint of every expert (M_j)
        device_mems (dict): memory available on every device (Mem_i)
        cost_local (float): local compute cost of one token (C_comp)
        cost_remote (float): remote communication cost of one token (C_comm)
        max_move_number: stored as given; not read by this base class
        load_balance_threshold: stored as given; not read by this base class
    """

    def __init__(
        self,
        num_devices,
        num_experts,
        expert_mems,
        device_mems,
        cost_local,
        cost_remote,
        max_move_number,
        load_balance_threshold,
    ):
        # Topology.
        self.num_devices = num_devices
        self.num_experts = num_experts
        # Memory model.
        self.expert_mems = expert_mems
        self.device_mems = device_mems
        # Cost model.
        self.cost_local = cost_local
        self.cost_remote = cost_remote
        # Tuning knobs consumed by subclasses.
        self.max_move_number = max_move_number
        self.load_balance_threshold = load_balance_threshold

    @staticmethod
    def get_expert_total_demand(total_traffic):
        """Return the per-expert demand: the column-wise sum of the traffic matrix."""
        return np.sum(total_traffic, axis=0)

    def placement_greedy(self, traffic_data, origin_device_to_expert: dict = None, shared_expert_id: int = None):
        """
        Compute an expert placement with the greedy strategy.

        Args:
            traffic_data (dict): per-device vector with the number of tokens
                routed to each expert, keyed by device id
            origin_device_to_expert (dict): expert list currently deployed on
                each device; when given, the sum of the list lengths sizes the
                returned identity tensor, otherwise ``num_experts`` is used
            shared_expert_id (int): id of the shared expert, or None

        Returns:
            dict: with the keys ``final_placement``, ``final_memory_usage``,
            ``device_to_expert_map`` and ``expert_trans_tensor``
        """
        prepared = self.data_preprocess(traffic_data)
        placement, sorted_experts, total_traffic, used_mems, global_expert_load = prepared

        if origin_device_to_expert is None:
            tensor_size = self.num_experts
        else:
            tensor_size = sum(len(experts) for experts in origin_device_to_expert.values())
        expert_trans_tensor = torch.eye(tensor_size)

        # The shared expert (if any) is pinned on every device before anything else.
        self.process_share_expert(placement, shared_expert_id, used_mems)

        device_to_expert = self.initial_placement(
            LoadData(
                placement=placement,
                shared_expert_id=shared_expert_id,
                total_traffic=total_traffic,
                used_mems=used_mems,
                origin_device_to_expert=origin_device_to_expert,
                sorted_experts=sorted_experts,
                expert_trans_tensor=expert_trans_tensor,
            )
        )

        # placement / used_mems are the same objects as above, so the second
        # phase continues from the state left by the first one.
        self.optimize_min_max(
            LoadData(
                placement=placement,
                shared_expert_id=shared_expert_id,
                total_traffic=total_traffic,
                used_mems=used_mems,
                global_expert_load=global_expert_load,
                device_to_expert=device_to_expert,
            )
        )

        return {
            "final_placement": placement,
            "final_memory_usage": used_mems,
            "device_to_expert_map": device_to_expert,
            "expert_trans_tensor": expert_trans_tensor,
        }

    def optimize_min_max(self, load_data: "LoadData"):
        """Second-phase hook; the base class performs no optimization."""
        pass

    def initial_placement(self, load_data: "LoadData"):
        """
        Place every non-shared expert once, in the order of ``sorted_experts``.

        Each expert goes to the device that can still hold it and has the
        least memory in use (the lowest device id wins ties).
        ``load_data.placement`` and ``load_data.used_mems`` are updated in place.

        Returns:
            dict: device id -> list of expert ids placed on it by this step

        Raises:
            MemoryError: when no device has room left for an expert
        """
        logger.debug("--- [Phase 1] Start initial placement ---")
        device_to_expert = {device_id: [] for device_id in range(self.num_devices)}

        shared_id = load_data.shared_expert_id
        if shared_id is not None:
            # The shared expert was already deployed everywhere; skip it here.
            load_data.sorted_experts = [exp for exp in load_data.sorted_experts if exp != shared_id]

        used = load_data.used_mems
        for expert_id in load_data.sorted_experts:
            chosen = None
            lowest_usage = float('inf')
            for device_id in range(self.num_devices):
                if used[device_id] + self.expert_mems[expert_id] > self.device_mems[device_id]:
                    continue
                if used[device_id] < lowest_usage:
                    lowest_usage = used[device_id]
                    chosen = device_id

            if chosen is None:
                logger.error(
                    "[MindIE-SD/eplb] Expert initial placement failed. "
                    "issue=no device has enough memory for expert, expert_id=%s, expected_free_memory>=%sGB, "
                    "actual_device_memory=%sGB, actual_used_memory=%sGB. "
                    "possible_cause=EPLB memory configuration cannot fit every expert. "
                    "Troubleshooting: increase expert_per_rank/redundant memory budget, reduce expert size, "
                    "or check world_size and expert_num settings.",
                    expert_id,
                    self.expert_mems[expert_id],
                    self.device_mems,
                    load_data.used_mems,
                )
                raise MemoryError(
                    f"Error:expert {expert_id} (need {self.expert_mems[expert_id]}GB) "
                    f"Unable to locate initial position for this expert on any device, check memory configuration."
                )

            load_data.placement[chosen, expert_id] = 1
            device_to_expert[chosen].append(expert_id)
            used[chosen] += self.expert_mems[expert_id]
            logger.debug(
                f" -> Place high-load expert {expert_id} to device {chosen} "
                f"(Current memory: {load_data.used_mems[chosen]:.1f}GB)"
            )

        logger.debug("--- [Phase 1] Initial placement completed. ---\n")
        return device_to_expert

    def data_preprocess(self, traffic_data):
        """
        Build the initial working state from the raw traffic report.

        Returns:
            tuple: (zeroed integer placement matrix, expert ids sorted by
            descending demand, traffic matrix, zeroed per-device memory usage,
            per-expert demand)
        """
        shape = (self.num_devices, self.num_experts)

        total_traffic = np.zeros(shape)
        for device_id, expert_requests in traffic_data.items():
            total_traffic[device_id] += expert_requests

        demand = self.get_expert_total_demand(total_traffic)
        by_descending_demand = np.argsort(demand)[::-1]

        placement = np.zeros(shape, dtype=int)
        used_mems = dict.fromkeys(range(self.num_devices), 0.0)
        return placement, by_descending_demand, total_traffic, used_mems, demand

    def process_share_expert(self, placement, shared_expert_id, used_mems):
        """
        Deploy the shared expert on every device (no-op when the id is None).

        ``placement`` and ``used_mems`` are updated in place.

        Raises:
            MemoryError: when a device's total memory is smaller than the
                shared expert's footprint
        """
        if shared_expert_id is None:
            return

        logger.debug(
            f"--- [Preprocessing] Shared Expert {shared_expert_id} detected, forcing deployment on all devices ---"
        )
        shared_mem = self.expert_mems[shared_expert_id]
        for device_id in range(self.num_devices):
            capacity = self.device_mems[device_id]
            if capacity < shared_mem:
                logger.error(
                    "[MindIE-SD/eplb] Shared expert placement failed. "
                    "issue=device memory is insufficient for shared expert, device_id=%s, shared_expert_id=%s, "
                    "expected_device_memory>=%sGB, actual_device_memory=%sGB. "
                    "possible_cause=shared expert memory requirement exceeds one or more device budgets. "
                    "Troubleshooting: increase per-device expert memory budget, reduce shared expert size, "
                    "or adjust EPLB deployment configuration.",
                    device_id,
                    shared_expert_id,
                    shared_mem,
                    capacity,
                )
                raise MemoryError(
                    f"Device {device_id} (memory {self.device_mems[device_id]}GB) cannot accommodate the shared expert."
                    f"{shared_expert_id} (need {shared_mem}GB)。"
                )
            placement[device_id, shared_expert_id] = 1
            used_mems[device_id] += shared_mem
        logger.debug("--- [Preprocessing] Shared expert deployment completed ---\n")
class A2ARedundantExpertService(EPLBService):
    """
    Dynamic scheduling of redundant experts for the all-to-all communication mode.
    """

    def __init__(
        self,
        num_devices,
        num_experts,
        expert_mems,
        device_mems,
        cost_local,
        cost_remote,
        max_move_number,
        load_balance_threshold,
    ):
        # Nothing subclass-specific to set up; forward everything to the base.
        super().__init__(
            num_devices,
            num_experts,
            expert_mems,
            device_mems,
            cost_local,
            cost_remote,
            max_move_number,
            load_balance_threshold,
        )

    def optimize_min_max(self, load_data: LoadData):
        """
        Greedily replicate experts until no replication is beneficial.

        Every round scans all (expert, device) pairs where the expert is not
        yet on the device and still fits in memory. A replica's gain is the
        device's own traffic to that expert times (cost_remote - cost_local);
        the pair with the highest gain per unit of expert memory is applied.
        ``placement``, ``device_to_expert`` and ``used_mems`` of ``load_data``
        are updated in place. The loop ends when no pair scores above 1e-9.
        """
        logger.debug("--- [Phase 2] Starting iterative replication optimization.---")
        placement = load_data.placement
        traffic_matrix = load_data.total_traffic
        used = load_data.used_mems
        shared_id = load_data.shared_expert_id
        saving_per_token = self.cost_remote - self.cost_local

        round_no = 0
        while True:
            round_no += 1

            # Per-device load under the current placement (used for the log only).
            loads = np.zeros(self.num_devices)
            for dev in range(self.num_devices):
                acc = 0
                for exp in range(self.num_experts):
                    unit_cost = self.cost_local if placement[dev, exp] == 1 else self.cost_remote
                    acc += traffic_matrix[dev, exp] * unit_cost
                loads[dev] = acc
            bottleneck_device = np.argmax(loads)
            max_load = np.max(loads)
            logger.debug(
                f"\nRound {round_no} | The current system bottleneck (maximum load): {max_load:,.0f} "
                f"(on device {bottleneck_device})"
            )

            # Find the replication with the best gain-per-memory score.
            winner = None
            max_score = 1e-9
            for expert_id in range(self.num_experts):
                if shared_id is not None and expert_id == shared_id:
                    continue
                for device_id in range(self.num_devices):
                    already_there = placement[device_id, expert_id] == 1
                    if already_there:
                        continue
                    fits = used[device_id] + self.expert_mems[expert_id] <= self.device_mems[device_id]
                    if not fits:
                        continue
                    gain = traffic_matrix[device_id, expert_id] * saving_per_token
                    cost_mem = self.expert_mems[expert_id]
                    if gain <= 0:
                        continue
                    score = gain / cost_mem
                    if score > max_score:
                        max_score = score
                        winner = (device_id, expert_id, gain)

            if winner is None:
                logger.debug(
                    "\n--- [Phase 2] Optimization completed: No more beneficial replication operations found.---"
                )
                break

            # Apply the chosen replication.
            dev_to_add, exp_to_add, move_gain = winner
            placement[dev_to_add, exp_to_add] = 1
            load_data.device_to_expert[dev_to_add].append(exp_to_add)
            used[dev_to_add] += self.expert_mems[exp_to_add]
            logger.debug(
                f" -> Best Move: Copy Expert {exp_to_add} to Device {dev_to_add} (Benefit Score: {max_score:,.2f})"
            )
            logger.debug(f" Load benefit: Reduced load by {move_gain:,.0f} for Device {dev_to_add}")
            logger.debug(
                f" New memory status: {load_data.used_mems[dev_to_add]:.1f}GB / {self.device_mems[dev_to_add]}GB"
            )
class AGRedundantExpertService(EPLBService):
    """
    Dynamic scheduling of redundant experts for the all-gather communication mode.
    """

    def __init__(
        self,
        num_devices,
        num_experts,
        expert_mems,
        device_mems,
        cost_local,
        cost_remote,
        max_move_number,
        load_balance_threshold,
    ):
        super().__init__(
            num_devices,
            num_experts,
            expert_mems,
            device_mems,
            cost_local,
            cost_remote,
            max_move_number,
            load_balance_threshold,
        )

    @staticmethod
    def get_expert_total_demand(total_traffic):
        """Return the per-expert demand: the column-wise mean of the traffic matrix."""
        return np.mean(total_traffic, axis=0)

    def optimize_min_max(self, load_data: LoadData):
        """
        Greedily replicate experts to lower the maximum device load.

        A device's load is the sum, over the experts it hosts, of
        ``global_expert_load[expert] / replica_count * cost_local``. Every
        round simulates adding one replica for each admissible
        (expert, device) pair and applies the one with the largest reduction
        of the maximum load per unit of expert memory. ``placement``,
        ``device_to_expert`` and ``used_mems`` of ``load_data`` are updated in
        place. The loop ends when no pair scores above 1e-9.
        """
        logger.debug("--- [Phase 2] Starting iterative replication optimization ---")
        iteration = 1
        while True:
            # Number of devices hosting each expert. The counts are only ever
            # used as divisors for experts that are hosted somewhere, so no
            # zero-replacement is needed (the former `= 1e-9` patch was
            # truncated to 0 by the integer array anyway).
            replicas_per_expert = np.sum(load_data.placement, axis=0)
            current_loads = np.zeros(self.num_devices)
            for i in range(self.num_devices):
                load = 0
                for j in range(self.num_experts):
                    if load_data.placement[i, j] == 1:
                        load += (load_data.global_expert_load[j] / replicas_per_expert[j]) * self.cost_local
                current_loads[i] = load
            bottleneck_device = np.argmax(current_loads)
            max_load = np.max(current_loads)
            logger.debug(
                f"\nRound {iteration} | The current system bottleneck (maximum load): {max_load:,.0f} "
                f"(on device {bottleneck_device})"
            )
            logger.debug(f" Devices load: {[f'{device_load:,.0f}' for device_load in current_loads]}")

            best_move = None
            max_score = 1e-9
            for expert_id in range(self.num_experts):
                if load_data.shared_expert_id is not None and expert_id == load_data.shared_expert_id:
                    continue

                # Everything below depends on the expert only, not on the
                # candidate device, so it is computed once per expert.
                old_replicas = replicas_per_expert[expert_id]
                devices_for_expert = np.where(load_data.placement[:, expert_id] == 1)[0]
                # Share carried by each replica once one more replica exists.
                device_add_load = (
                    load_data.global_expert_load[expert_id] / (old_replicas + 1)
                ) * self.cost_local
                cost_mem = self.expert_mems[expert_id]

                for device_id in range(self.num_devices):
                    if load_data.placement[device_id, expert_id] == 1:
                        continue
                    if load_data.used_mems[device_id] + self.expert_mems[expert_id] > self.device_mems[device_id]:
                        continue

                    # Simulate the move. Only the devices related to this
                    # expert change, so the loads are patched instead of being
                    # recomputed from a copied placement matrix.
                    new_loads = current_loads.copy()
                    new_loads[device_id] += device_add_load
                    if devices_for_expert.size:
                        # Existing hosts drop from G/r to G/(r+1), i.e. by
                        # (G/(r+1)) / r each.
                        new_loads[devices_for_expert] -= device_add_load / old_replicas

                    new_max_load = np.max(new_loads)
                    gain = max_load - new_max_load
                    if gain <= 0:
                        continue
                    score = gain / cost_mem
                    if score > max_score:
                        max_score = score
                        best_move = (device_id, expert_id, gain)

            if best_move is None:
                logger.debug(
                    "\n--- [Phase 2] Optimization completed: No more beneficial replication operations found.---"
                )
                break

            dev_to_add, exp_to_add, move_gain = best_move
            load_data.placement[dev_to_add, exp_to_add] = 1
            load_data.device_to_expert[dev_to_add].append(exp_to_add)
            load_data.used_mems[dev_to_add] += self.expert_mems[exp_to_add]
            logger.debug(
                f" -> Best Move: Copy Expert {exp_to_add} to Device {dev_to_add} (Benefit Score: {max_score:,.2f})"
            )
            logger.debug(f" Load benefit: System maximum load reduced by {move_gain:,.0f}")
            logger.debug(
                f" New memory status: {load_data.used_mems[dev_to_add]:.1f}GB / {self.device_mems[dev_to_add]}GB"
            )
            iteration += 1
class ExpertExchangeService(EPLBService):
    """
    Dynamic scheduling scheme based on exchanging experts between devices.
    """

    def __init__(
        self,
        num_devices,
        num_experts,
        expert_mems,
        device_mems,
        cost_local,
        cost_remote,
        max_move_number,
        load_balance_threshold,
    ):
        super().__init__(
            num_devices,
            num_experts,
            expert_mems,
            device_mems,
            cost_local,
            cost_remote,
            max_move_number,
            load_balance_threshold,
        )

    @staticmethod
    def _find_swap_positions(trans_traffic, experts_on_max, experts_on_min):
        """
        Return the list positions of the best admissible swap, or None.

        Candidate pairs are visited in ascending order of ``trans_traffic``
        (rows index ``experts_on_max``, columns index ``experts_on_min``). A
        pair is admissible when neither expert is already present on the
        device it would move to.
        """
        for flat_idx in np.argsort(trans_traffic.flatten()):
            pos_in_max, pos_in_min = np.unravel_index(flat_idx, trans_traffic.shape)
            if (
                experts_on_min[pos_in_min] not in experts_on_max
                and experts_on_max[pos_in_max] not in experts_on_min
            ):
                return pos_in_max, pos_in_min
        return None

    def initial_placement(self, load_data: LoadData):
        """
        Balance device loads by swapping experts between the two extreme devices.

        Starting from ``load_data.origin_device_to_expert`` (or, when that is
        None, from an even contiguous split of ``num_experts // num_devices``
        experts per device), each round exchanges one expert of the most
        loaded device with one expert of the least loaded device. A device's
        load is ``total_traffic[device, experts_on_device].sum()``. At most
        ``max_move_number`` rounds are run; iteration stops early when the
        max-min load gap drops below ``load_balance_threshold``, when no
        admissible pair exists, or when the best pair brings no gain.

        ``load_data.used_mems`` and ``load_data.expert_trans_tensor`` are
        updated in place; every swap exchanges the two matching tensor columns.

        Returns:
            dict: device id -> expert list after the exchanges

        Raises:
            ParametersInvalid: when an extreme device is missing from the mapping
        """
        origin = load_data.origin_device_to_expert
        device_loads = [0] * self.num_devices

        # The None check must come before any use of the mapping; copying
        # first would raise AttributeError and make this branch unreachable.
        if origin is None:
            device_to_expert = {}
            experts_per_device = self.num_experts // self.num_devices
            for device_id in range(self.num_devices):
                start = device_id * experts_per_device
                end = start + experts_per_device
                device_loads[device_id] = load_data.total_traffic[device_id, start:end].sum()
                device_to_expert[device_id] = list(range(start, end))
                for expert_id in range(start, end):
                    load_data.used_mems[device_id] += self.expert_mems[expert_id]
        else:
            # NOTE(review): shallow copy - the per-device expert lists are
            # still shared with the caller and are modified by the swaps below.
            device_to_expert = origin.copy()
            for device_id in range(self.num_devices):
                local_experts = device_to_expert[device_id]
                device_loads[device_id] = load_data.total_traffic[device_id, local_experts].sum()
                for expert_id in local_experts:
                    load_data.used_mems[device_id] += self.expert_mems[expert_id]
        logger.debug(f"------------ Initial load status of the devices: {device_loads} ---------------------")

        for move_num in range(self.max_move_number):
            logger.debug(f"------------ Round {move_num} ---------------------")
            max_load = max(device_loads)
            max_device_index = device_loads.index(max_load)
            min_load = min(device_loads)
            min_device_index = device_loads.index(min_load)
            # Amount of load that would ideally move from the max to the min device.
            delta_load = (max_load - min_load) // 2
            logger.debug(f"--- Max-min device load diff {delta_load * 2}")
            if delta_load * 2 < self.load_balance_threshold:
                logger.debug(
                    f"------------ Max-min device load diff less than {self.load_balance_threshold} "
                    f"End the iteration ---------------------"
                )
                break

            if max_device_index not in device_to_expert or min_device_index not in device_to_expert:
                raise ParametersInvalid("[greedy] expert not in index list")
            expert_max_idx = device_to_expert[max_device_index]
            expert_min_idx = device_to_expert[min_device_index]
            max_load_device_traffic = load_data.total_traffic[max_device_index, expert_max_idx]
            min_load_device_traffic = load_data.total_traffic[min_device_index, expert_min_idx]

            # |(load leaving max) - (load leaving min) - delta| for every pair:
            # the smaller, the closer the swap moves exactly delta_load.
            trans_traffic = np.abs((max_load_device_traffic[:, np.newaxis] - min_load_device_traffic) - delta_load)
            swap = self._find_swap_positions(trans_traffic, expert_max_idx, expert_min_idx)
            if swap is None:
                logger.debug("\n--- No more beneficial replication operations found.---")
                break
            experid_from_max_to_min_index, experid_from_min_to_max_index = swap
            experid_from_max_to_min = expert_max_idx[experid_from_max_to_min_index]
            experid_from_min_to_max = expert_min_idx[experid_from_min_to_max_index]

            move_expert_load = (
                load_data.total_traffic[max_device_index, experid_from_max_to_min]
                - load_data.total_traffic[min_device_index, experid_from_min_to_max]
            )
            new_delta_load = abs((max_load - move_expert_load) - (min_load + move_expert_load)) // 2
            gain = delta_load - new_delta_load
            if gain <= 0:
                # Nothing was modified in this round, so every later round
                # would pick the same pair and reach the same verdict.
                logger.debug("--- Best candidate exchange brings no gain, stop iterating ---")
                break

            device_to_expert[max_device_index][experid_from_max_to_min_index] = experid_from_min_to_max
            device_to_expert[min_device_index][experid_from_min_to_max_index] = experid_from_max_to_min
            device_loads[max_device_index] -= move_expert_load
            device_loads[min_device_index] += move_expert_load
            logger.debug(
                f"--- Move expert {experid_from_max_to_min} from device {max_device_index} to {min_device_index}, "
                f"device {max_device_index} load "
                f"reduced by {load_data.total_traffic[max_device_index, experid_from_max_to_min]}, "
                f"device {min_device_index} load "
                f"increased by {load_data.total_traffic[max_device_index, experid_from_max_to_min]}; "
                f"Move expert {experid_from_min_to_max} from device {min_device_index} to {max_device_index},"
                f"device {min_device_index} load "
                f"reduced by {load_data.total_traffic[min_device_index, experid_from_min_to_max]}, "
                f"device {max_device_index} load "
                f"increased {load_data.total_traffic[min_device_index, experid_from_min_to_max]}. "
            )
            logger.debug(f"--- Latest Max-min device load diff: {new_delta_load * 2}")

            load_data.used_mems[min_device_index] += self.expert_mems[experid_from_max_to_min]
            load_data.used_mems[min_device_index] -= self.expert_mems[experid_from_min_to_max]
            load_data.used_mems[max_device_index] += self.expert_mems[experid_from_min_to_max]
            load_data.used_mems[max_device_index] -= self.expert_mems[experid_from_max_to_min]

            # Record the swap by exchanging the two matching columns. The
            # column of a slot is its position plus the device's offset,
            # computed with an equal number of slots per device.
            slots_per_device = load_data.expert_trans_tensor.shape[0] // self.num_devices
            i = experid_from_max_to_min_index + (max_device_index * slots_per_device)
            j = experid_from_min_to_max_index + (min_device_index * slots_per_device)
            load_data.expert_trans_tensor[:, [i, j]] = load_data.expert_trans_tensor[:, [j, i]]
            logger.debug(f"--- Current load status: {device_loads} ---")
        return device_to_expert
def process_final_placement(results, num_experts):
    """
    Turn a placement result into the routing tables returned to the caller.

    Args:
        results (dict): must contain ``device_to_expert_map`` (device id ->
            expert ids; the ids are used as list indices, so they are expected
            to be 0..len-1) and ``expert_trans_tensor``
        num_experts (int): number of experts to build the tables for

    Returns:
        tuple:
            - device_indices: per device, for each expert the device to route
              to - the device itself when it hosts the expert, otherwise a
              randomly chosen hosting device
            - local_expert_indices: per device, for each expert its position in
              that device's expert list, or -1 when not hosted there
            - local_expert_list: the expert lists in mapping order
            - expert_trans_tensor: passed through unchanged
    """
    placement_map = results["device_to_expert_map"]
    device_count = len(placement_map)

    # First pass: reverse lookup (expert -> hosting devices) and local slots.
    hosts_of_expert = [[] for _ in range(num_experts)]
    local_expert_indices = [[-1] * num_experts for _ in range(device_count)]
    for device, experts in placement_map.items():
        for slot, expert in enumerate(experts):
            hosts_of_expert[expert].append(device)
            local_expert_indices[device][expert] = slot

    # Second pass: routing target of every expert as seen from every device.
    device_indices = [[] for _ in range(device_count)]
    for device, experts in placement_map.items():
        row = device_indices[device]
        for expert in range(num_experts):
            if expert in experts:
                target = device
            else:
                target = random.choice(hosts_of_expert[expert])
            row.append(target)

    local_expert_list = list(placement_map.values())
    return device_indices, local_expert_indices, local_expert_list, results["expert_trans_tensor"]
def process_expert_num(load_dic, num_experts):
    """
    Replace every load vector with its element-wise differences, in place.

    Each value becomes ``np.diff(value, prepend=0)`` cast to int64, i.e. the
    first element is kept and every other element becomes the increase over
    its predecessor (presumably the inputs are running totals - verify
    against the caller).

    Args:
        load_dic (dict): device id -> load vector; modified in place
        num_experts: unused

    Returns:
        dict: the same ``load_dic`` object
    """
    for device_id in load_dic:
        increments = np.diff(load_dic[device_id], prepend=0)
        load_dic[device_id] = increments.astype(dtype=np.int64)
    return load_dic
def eplb_greedy(**kwargs):
    """
    Entry point of the algorithm.

    Args:
        response (dict): number of tokens each device routed to each expert,
            format: {device_id: np.array([...])}; converted in place by
            ``process_expert_num``
        device_to_expert (dict): expert list deployed on each device,
            format: {device_id: np.array([...])}
        algorithm_type (str): one of 'A2A' (all-to-all redundant experts),
            'AG' (all-gather redundant experts) or 'EX' (expert exchange)
        world_size (int): ep_size
        expert_num (int): global number of experts
        max_move (int): maximum number of expert moves in EX mode
        redundant (int): number of redundant experts

    Returns:
        bool: whether an update is triggered (False when the converted load
            of device 0 is all zeros)
        list: per-device token forwarding route
        list: per-device local index of every expert
        list: latest expert list deployed on each device
        tensor: expert exchange tensor produced in the EX (expert exchange)
            scenario; it can be matrix-multiplied with the original expert
            distribution tensor to obtain the latest expert distribution

    Raises:
        KeyError: when a required keyword argument is missing
        ParametersInvalid: when ``algorithm_type`` is not a supported value
    """
    response = kwargs['response']
    algorithm_type = kwargs['algorithm_type']
    device_to_expert = kwargs['device_to_expert']
    world_size = kwargs['world_size']
    expert_num = kwargs['expert_num']
    max_move = kwargs['max_move']
    redundant = kwargs['redundant']

    load_balance_threshold = 100
    cost_local = 1
    cost_remote = 10

    # Resolve the algorithm before touching `response`, so an unsupported
    # type fails with a clear error instead of an AttributeError on None.
    service_classes = {
        'A2A': A2ARedundantExpertService,
        'AG': AGRedundantExpertService,
        'EX': ExpertExchangeService,
    }
    service_cls = service_classes.get(algorithm_type)
    if service_cls is None:
        raise ParametersInvalid(
            f"[greedy] unsupported algorithm_type {algorithm_type!r}, expected one of {sorted(service_classes)}"
        )

    current_pattern = process_expert_num(response, expert_num)
    update = not (response[0] == 0).all()

    # Memory model: every expert costs one unit, every device holds
    # expert_per_rank units.
    expert_per_rank = (expert_num // world_size) + redundant
    expert_mems = {i: 1.0 for i in range(expert_num)}
    device_mems = {i: expert_per_rank for i in range(world_size)}

    # Only the selected service is instantiated.
    algorithm_service = service_cls(
        world_size, expert_num, expert_mems, device_mems, cost_local, cost_remote, max_move, load_balance_threshold
    )
    result = algorithm_service.placement_greedy(current_pattern, device_to_expert)
    output = process_final_placement(result, expert_num)
    device_indices, local_expert_indices, local_expert_list, expert_trans_tensor = output
    logger.debug("[MindIE-SD/eplb] Current expert placement computed. local_expert_list=%s.", local_expert_list)
    return update, device_indices, local_expert_indices, local_expert_list, expert_trans_tensor