from enum import Enum, auto
from collections import deque
from multiprocessing import Queue, Process
from multiprocessing import Pool
from ms_service_profiler.task.task_register import filter_dag
from ms_service_profiler.task.task_register import TaskDag
from ms_service_profiler.task.task import Task
from ms_service_profiler.utils.timer import timer
from ms_service_profiler.utils.log import logger, Color
from ms_service_profiler.utils.error import OtherTaskError
class DefaultValue(Enum):
UNFILLED = auto()
class SubprocessInfo:
def __init__(self) -> None:
self.executor = None
self.queues = []
self.processes = []
def get_queues(self):
return self.queues
def add_queue(self, queue):
self.queues.append(queue)
def new_process(self, send_queue, args):
recv_queue = Queue()
args = args + (recv_queue, send_queue)
process = Process(target=task_run, args=args)
self.queues.append(recv_queue)
self.processes.append(process)
process.start()
def is_alive(self):
return any((x.is_alive for x in self.processes))
class TaskManager:
def __init__(self, task_dag: TaskDag) -> None:
self.task_dag = task_dag
self.task_manager_info_dict = dict()
self.manager_recv_queue = Queue()
self.pool = []
self.pool_owner = []
self.total_tasks = 0
self.finished_tasks = 0
self.error_tasks = 0
def init_task(self, task_name) -> None:
self.task_manager_info_dict.setdefault(task_name, dict(process_pool_info=[],
wait_pool_index=[],
queues=[],
state="unstart",
gather_data=deque()))
return self.task_manager_info_dict[task_name]
def init_task_waiting_pool(self, src_dag, pool_index):
for task_name, _ in src_dag.get_ordered_task_names():
task_manager_info = self.init_task(task_name)
task_manager_info.get("wait_pool_index").append(pool_index)
def create_pool(self, data_source_task, single_data_list, src_dag, args):
task_manager_info = self.init_task(data_source_task.name)
process_info = SubprocessInfo()
pool_index = len(self.pool)
self.pool.append(process_info)
self.pool_owner.append(data_source_task.name)
task_manager_info.get("process_pool_info").append(process_info)
queues = process_info.get_queues()
self.init_task_waiting_pool(src_dag, pool_index)
for _, data in enumerate(single_data_list):
process_info.new_process(self.manager_recv_queue, (data, src_dag, pool_index, args))
task_manager_info.get("queues").extend(queues)
self.send_go(data_source_task.name)
def set_no_source_data(self, task_name):
task_manager_info = self.init_task(task_name)
task_manager_info["state"] = "no_source_data"
for next_task_name in self.task_dag.get_next_task_names(task_name):
no_source_flag = True
for prev_task_name in self.task_dag.get_prev_task_names(next_task_name):
prev_task_manager_info = self.init_task(prev_task_name)
if prev_task_manager_info.get("state") != "no_source_data":
no_source_flag = False
break
if no_source_flag:
self.set_no_source_data(next_task_name)
def get_task_state(self, task_name):
task_manager_info = self.init_task(task_name)
return task_manager_info.get("state", "unstart")
def get_task_process_cnt(self, task_name):
task_manager_info = self.init_task(task_name)
return len(task_manager_info.get("queues", []))
def is_all_finished(self):
return all((x is None for x in self.pool_owner)) or all((not x.is_alive() for x in self.pool))
def is_all_prev_finished(self, task_name):
error_flag = False
for prev_task_name in self.task_dag.get_prev_task_names(task_name):
if self.get_task_state(prev_task_name) == "error":
error_flag = True
if self.get_task_state(prev_task_name) not in ["finished", "no_source_data", "error"]:
return False, error_flag
task_manager_info = self.init_task(task_name)
for wait_pool_index in task_manager_info.get("wait_pool_index"):
if self.pool_owner[wait_pool_index] != task_name:
return False, error_flag
return True, error_flag
def update_task_progress(self, flag='finished'):
if flag == 'finished':
self.finished_tasks += 1
elif flag == 'error':
self.error_tasks += 1
completed_tasks = self.finished_tasks + self.error_tasks
if self.total_tasks > 0:
progress = min(completed_tasks / self.total_tasks * 100, 100)
error_msg = f", {self.error_tasks} tasks failed" if self.error_tasks > 0 else ""
logger.info(f"Progress: {progress:.1f}% ({completed_tasks}/{self.total_tasks} tasks completed{error_msg})")
else:
logger.info(f"Progress: 100.0% (No tasks need to be processed)")
def set_task_finished(self, finished_task_name, next_task_set):
task_manager_info = self.init_task(finished_task_name)
if task_manager_info["state"] != 'error':
task_manager_info["state"] = "finished"
logger.info(f"{Color.BRIGHT_GREEN}task [{finished_task_name}] finished.{Color.RESET}")
self.update_task_progress()
for pool_index, next_task_name in next_task_set:
self.pool_owner[pool_index] = next_task_name
if next_task_name is None:
continue
next_task_manager_info = self.init_task(next_task_name)
next_task_manager_info.get("process_pool_info", []).append(self.pool[pool_index])
next_task_manager_info.get("queues", []).extend(self.pool[pool_index].get_queues())
all_finished, has_err = self.is_all_prev_finished(next_task_name)
if next_task_manager_info["state"] != "unstart":
continue
if all_finished:
next_task_manager_info = self.init_task(next_task_name)
next_task_manager_info["state"] = "started"
if has_err:
self.send_go(next_task_name, "error")
else:
self.send_go(next_task_name)
else:
pass
def set_task_error(self, error_task_name, error_index, err_msg):
task_manager_info = self.init_task(error_task_name)
if task_manager_info.get("state") != "error":
task_manager_info.get("gather_data").clear()
if err_msg:
logger.error(f"{Color.BRIGHT_RED}task [{error_task_name}] error. due to {err_msg} {Color.RESET}")
self.update_task_progress(flag='error')
task_manager_info["state"] = "error"
return self.fill_gater_data(error_task_name, error_index, err_msg, ignore_error_state=True)
def send_msg_to_one_process(self, task_name, task_index, msg, param):
task_manager_info = self.init_task(task_name)
if task_index < len(task_manager_info.get("queues", [])):
task_manager_info.get("queues", [])[task_index].put((msg, param))
def send_msg_to_one_task(self, task_name, msg, param):
task_manager_info = self.init_task(task_name)
for queue in task_manager_info.get("queues", []):
queue.put((msg, param))
def send_go(self, task_name, go_msg="go"):
logger.info(f"{Color.BRIGHT_BLUE}task [{task_name}] start. {Color.RESET}")
for index in range(self.get_task_process_cnt(task_name)):
self.send_msg_to_one_process(task_name, index, go_msg, index)
def fill_gater_data(self, task_name, task_index, data, ignore_error_state=False):
task_manager_info = self.init_task(task_name)
if task_manager_info.get('state') == 'error' and ignore_error_state is False:
return None
gather_data = task_manager_info.get("gather_data")
for deque_index, list_item in enumerate(gather_data):
if list_item[task_index] is not DefaultValue.UNFILLED:
continue
list_item[task_index] = data
if deque_index == 0 and all((x is not DefaultValue.UNFILLED for x in list_item)):
return gather_data.popleft()
break
else:
cnt = len(task_manager_info.get("queues", []))
if cnt == 1:
return [data]
list_item = [DefaultValue.UNFILLED] * cnt
list_item[task_index] = data
gather_data.append(list_item)
return None
def start(self):
while (True):
who_task_name, who_index, msg, param = self.manager_recv_queue.get()
if msg == "finished":
data, after_error = param
gather_data = self.fill_gater_data(who_task_name, who_index, data, ignore_error_state=after_error)
if gather_data is not None:
self.send_msg_to_one_task(who_task_name, msg, None)
self.set_task_finished(who_task_name, set(gather_data))
if self.is_all_finished():
break
elif msg == "error":
err_msg = param
if err_msg is not None:
self.send_msg_to_one_task(who_task_name, msg, f"task {who_index} occurs error. due to {err_msg}")
gather_err = self.set_task_error(who_task_name, who_index, err_msg)
if gather_err is not None:
self.send_msg_to_one_task(who_task_name, "error_gather", gather_err)
if self.is_all_finished():
break
elif msg == "crash":
for task_name in self.task_manager_info_dict.keys():
self.send_msg_to_one_task(task_name, "error", None)
break
elif msg == "gather":
dst, data = param
gather_data = self.fill_gater_data(who_task_name, who_index, data)
if gather_data is not None:
self.send_msg_to_one_process(who_task_name, dst, msg, gather_data)
elif msg == "all_gather":
data = param
gather_data = self.fill_gater_data(who_task_name, who_index, data)
if gather_data is not None:
self.send_msg_to_one_task(who_task_name, msg, gather_data)
elif msg == "broadcast":
data = param
self.send_msg_to_one_task(who_task_name, msg, data)
elif msg == "send_to":
dst, data = param
self.send_msg_to_one_process(who_task_name, dst, 'p2p', data)
else:
pass
def task_run(input_data, src_dag, pool_index, args, recv_queue, send_queue):
task_index = None
run_res_data = dict(prof_path=input_data)
def recv():
msg, gather_data = recv_queue.get()
if msg == 'error':
raise OtherTaskError(gather_data)
return msg, gather_data
def recv_ignore_error():
msg = "error"
while msg == 'error':
msg, gather_data = recv_queue.get()
if msg != 'error':
return msg, gather_data
def finished_sync(task_name, task_index, next_task_name, after_error=False):
send_queue.put((task_name, task_index, "finished", ((pool_index, next_task_name), after_error)))
msg, _ = recv()
if msg != 'finished':
raise ValueError("Expected 'finished' message, but received: {}".format(msg))
def error_sync(task_name, task_index, err_msg=None):
send_queue.put((task_name, task_index, "error", err_msg))
msg, gather_data = recv_ignore_error()
return msg, gather_data
def crash(task_name, task_index):
send_queue.put((task_name, task_index, "crash", None))
for task_name, next_task_name in src_dag.get_ordered_task_names():
try:
_, task_index = recv()
task_info = TaskDag.get_task_reg_info(task_name)
if isinstance(task_info.task_cls, Task):
task_ins = task_info.task_cls
else:
task_ins = task_info.task_cls(args)
task_ins.init(task_name, task_index, recv, send_queue)
for depends_name in TaskDag.get_depends_data_names(task_name):
if depends_name in run_res_data:
task_ins.set_depends_result(depends_name, run_res_data.get(depends_name, None))
task_res = task_ins.run()
for output_name in TaskDag.get_outputs_data_names(task_name):
run_res_data.setdefault(output_name, task_res)
finished_sync(task_name, task_index, next_task_name)
except OtherTaskError as e:
error_sync(task_name, task_index, None)
finished_sync(task_name, task_index, next_task_name, after_error=True)
if args.log_level == 'verbose':
break
except Exception as e:
error_sync(task_name, task_index, str(e))
finished_sync(task_name, task_index, next_task_name, after_error=True)
logger.exception(f'{task_name}-{task_index} error. {str(e)}')
if args.log_level == 'verbose':
crash(task_name, task_index)
raise
@timer()
def tasks_run(data_source_tasks, task_dag, input_path, args):
task_manager = TaskManager(task_dag)
has_tasks = False
unique_tasks = set()
for data_source_task in data_source_tasks:
single_data_list = data_source_task.task_cls.get_prof_paths(input_path)
if len(single_data_list) == 0:
task_manager.set_no_source_data(data_source_task.name)
continue
has_tasks = True
src_dag = filter_dag(task_dag, data_source_task.name)
for task_name, _ in src_dag.get_ordered_task_names():
unique_tasks.add(task_name)
task_manager.create_pool(data_source_task, single_data_list, src_dag, args)
task_manager.total_tasks = len(unique_tasks)
if has_tasks:
task_manager.start()