import threading
from multiprocessing.connection import AuthenticationError
import torch_npu
from mindiesd.utils.logs.logging import logger
from .eplb_scheduler import get_manager_client
from .task_payload import TaskType, TaskPayload
from .task_transfer import ProfileTaskTransfer
from .task_handler import handle_profile_task, handle_update_layout_task, handle_unknown_task
# Lookup table from task type to the handler that processes it. The dispatch
# loop in this module falls back to handle_unknown_task for any task type
# that has no entry here.
TASK_DISPATCHER = {
    TaskType.PROFILE: handle_profile_task,
    TaskType.UPDATE_LAYOUT: handle_update_layout_task,
}
def _log_unknown_instruction(instruction):
    """Emit an error log for a queue item that is not a ``TaskPayload``.

    Args:
        instruction: The offending object; it is interpolated into the log
            message through the logger's lazy ``%s`` formatting.
    """
    message = (
        "[MindIE-SD/eplb] Unknown instruction ignored. "
        "issue=instruction is not a TaskPayload, instruction=%s. "
        "possible_cause=an unsupported object was inserted into the EPLB instruction queue. "
        "Troubleshooting: ensure the queue only contains PROFILE or UPDATE_LAYOUT TaskPayload objects."
    )
    logger.error(message, instruction)
def parse_module(module):
    """Collect dispatchers and expert-load collectors from a module tree.

    Walks ``module.named_modules()`` and keeps every sub-module that exposes
    both a ``dispatcher`` and an ``expert_load_collector`` attribute.

    Args:
        module: Object providing ``named_modules()``, which yields
            ``(name, sub_module)`` pairs.

    Returns:
        A ``(dispatcher_list, expert_load_collector_list)`` tuple of two
        lists. Entries at the same index come from the same sub-module, in
        traversal order.
    """
    required_attrs = ('dispatcher', 'expert_load_collector')
    # Gather both attributes per sub-module in a single pass so the two
    # resulting lists stay index-aligned.
    pairs = [
        (sub_module.dispatcher, sub_module.expert_load_collector)
        for _, sub_module in module.named_modules()
        if all(hasattr(sub_module, attr) for attr in required_attrs)
    ]
    dispatchers = [dispatcher for dispatcher, _ in pairs]
    collectors = [collector for _, collector in pairs]
    return dispatchers, collectors
def expert_info_transfer_pool(module, instruction_queue, upload_queue, device):
    """Run the instruction loop: pull items from the queue and dispatch them.

    Before looping, every expert-load collector found in ``module`` gets a
    ``ProfileTaskTransfer`` bound to ``instruction_queue``, the collector's
    index and its ``lb_interval``. The loop then blocks on
    ``instruction_queue.get()`` and:

    * returns when the item is ``None`` or equals ``'exit'``;
    * dispatches ``TaskPayload`` items through ``TASK_DISPATCHER`` (falling
      back to ``handle_unknown_task`` for unregistered task types);
    * logs an error for, and otherwise ignores, anything else.

    Args:
        module: Module tree passed to ``parse_module``.
        instruction_queue: Queue-like object providing a blocking ``get()``.
        upload_queue: Passed through to the task handlers.
        device: Device argument used to create the ``torch_npu`` stream that
            is handed to the task handlers.
    """
    dispatchers, collectors = parse_module(module)
    stream = torch_npu.npu.Stream(device)

    for position, load_collector in enumerate(collectors):
        load_collector.task_transfer = ProfileTaskTransfer(
            instruction_queue, position, load_collector.lb_interval
        )

    while True:
        instruction = instruction_queue.get()

        # Both None and the literal 'exit' act as shutdown sentinels.
        if instruction is None or instruction == 'exit':
            logger.debug("[MindIE-SD/eplb] Expert info transfer pool received exit instruction.")
            return

        if not isinstance(instruction, TaskPayload):
            _log_unknown_instruction(instruction)
            continue

        handle = TASK_DISPATCHER.get(instruction.task_type, handle_unknown_task)
        handle(instruction, upload_queue, collectors, dispatchers, stream)
def connect_to_schedule_manager(rank_in_group, ip, port, auth_key):
    """Connect to the schedule manager and fetch this rank's queues.

    Args:
        rank_in_group: Rank identifier; used in log messages and passed as
            ``rank`` when requesting the queues from the manager.
        ip: Manager address host part.
        port: Manager address port part.
        auth_key: Authentication key forwarded to ``get_manager_client``.

    Returns:
        An ``(instruction_queue, upload_queue)`` tuple obtained from the
        manager for ``rank_in_group``.

    Raises:
        AuthenticationError: Re-raised after logging when
            ``manager.connect()`` raises it.
        OSError: Re-raised after logging when ``manager.connect()`` raises it.
    """
    manager = get_manager_client((ip, port), auth_key)

    try:
        manager.connect()
    except AuthenticationError as error:
        auth_failure_message = (
            "[MindIE-SD/eplb] EPLB worker authentication failed. "
            "issue=scheduler rejected worker credentials, rank_in_group=%s, manager_addr=%s:%s, actual_error=%s. "
            "possible_cause=the scheduler and worker use different EPLB authentication keys. "
            "Troubleshooting: configure the same EPLB_AUTH_KEY or auth_key for the scheduler and every worker."
        )
        logger.error(auth_failure_message, rank_in_group, ip, port, error)
        raise
    except OSError as error:
        connect_failure_message = (
            "[MindIE-SD/eplb] EPLB worker failed to connect to scheduler. "
            "issue=scheduler connection unavailable, rank_in_group=%s, manager_addr=%s:%s, actual_error=%s. "
            "possible_cause=the scheduler is not running, the configured address is incorrect, "
            "or the network path is unavailable. "
            "Troubleshooting: check the scheduler process, listening address, worker configuration, and network."
        )
        logger.error(connect_failure_message, rank_in_group, ip, port, error)
        raise
    else:
        logger.debug(
            "[MindIE-SD/eplb] Connected to schedule manager. rank_in_group=%s, manager_addr=%s:%s.",
            rank_in_group,
            ip,
            port,
        )

    # Queue lookups stay outside the try block so their errors propagate
    # without being reported as connection failures.
    return (
        manager.get_instruction_queues(rank=rank_in_group),
        manager.get_upload_queues(rank=rank_in_group),
    )
def construct_expert_info_transfer_pool(**kwargs):
    """Connect to the schedule manager and start the instruction-loop thread.

    Keyword Args:
        module: Module tree handed to ``expert_info_transfer_pool``.
        rank_in_group: Rank used when connecting to the schedule manager.
        device: Device handed to ``expert_info_transfer_pool``.
        ip: Schedule manager host.
        port: Schedule manager port.
        auth_key: Schedule manager authentication key.

    Returns:
        ``(worker, instruction_queue)`` where ``worker`` is the started
        daemon thread, or ``(None, None)`` when either queue returned by
        ``connect_to_schedule_manager`` is ``None``.

    Raises:
        KeyError: If any of the keyword arguments listed above is missing.
    """
    expected_keys = ('module', 'rank_in_group', 'device', 'ip', 'port', 'auth_key')
    module, rank_in_group, device, ip, port, auth_key = [kwargs[key] for key in expected_keys]

    instruction_queue, upload_queue = connect_to_schedule_manager(rank_in_group, ip, port, auth_key)
    if instruction_queue is None:
        return None, None
    if upload_queue is None:
        return None, None

    thread_args = (module, instruction_queue, upload_queue, device)
    worker = threading.Thread(target=expert_info_transfer_pool, args=thread_args, daemon=True)
    worker.start()
    return worker, instruction_queue