import os
import threading
from typing import Callable, Any, Optional, List
import ttp_logger
import ttp_c2python_api
# Integer status codes used as return values by the handler methods and
# wrapper functions in this module: 0 for success, 1 for failure.
RET_OK = 0
RET_ERROR = 1
class MindxEngineHandler:
    """Singleton registry that maps action names to registered callables.

    Only the action names listed by :meth:`support_actions` can be registered.
    Registration and execution are serialised through ``self.lock``.
    """

    _instance = None
    # Guards creation of the singleton instance in __new__.
    _lock = threading.Lock()
    # action name -> callable; replaced by an empty dict when the singleton
    # instance is first created.
    _action_func_map = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on the first call."""
        with cls._lock:
            if cls._instance is None:
                # object.__new__ takes no extra arguments, so none are forwarded.
                instance = super().__new__(cls)
                # The per-instance lock is created here, exactly once. Creating
                # it in __init__ would replace it on every MindxEngineHandler()
                # call, because __init__ runs again each time the singleton is
                # returned, and a thread holding the old lock would no longer
                # exclude other threads.
                instance.lock = threading.Lock()
                cls._action_func_map = {}
                cls._instance = instance
        return cls._instance

    def __init__(self):
        """Do nothing; all one-time setup is performed in ``__new__``."""

    @staticmethod
    def support_actions() -> List:
        """Return a new list of the action names that may be registered."""
        return ['report_fault_ranks', 'report_stop_complete', 'report_strategies', 'report_result']

    def check_supported_action(self, action: str) -> bool:
        """Return True if ``action`` is one of the names in ``support_actions()``."""
        return action in self.support_actions()

    def register_api_callback(self, action: str, func: Callable):
        """Register ``func`` as the callback for ``action``.

        An existing registration for the same action is overwritten.

        Returns:
            RET_OK on success; RET_ERROR if ``action`` is not supported or
            ``func`` is not callable.
        """
        with self.lock:
            if not self.check_supported_action(action):
                ttp_logger.LOGGER.warning(f"register action[{action}] fail, "
                                          f"only support action {MindxEngineHandler.support_actions()}")
                return RET_ERROR
            # callable(None) is False, so this also rejects a missing func.
            if not callable(func):
                ttp_logger.LOGGER.error(f"register action[{action}] fail, func must be a callable!")
                return RET_ERROR
            self._action_func_map[action] = func
            ttp_logger.LOGGER.info(f"register action[{action}] success")
            return RET_OK

    def execute_callback(self, action: str, args: Optional[Any]):
        """Invoke the callback registered for ``action``.

        ``args`` is unpacked as positional arguments (``func(*args)``); when it
        is None the callback is called with no arguments.

        ``self.lock`` is held while the callback runs. ``threading.Lock`` is
        not reentrant, so a callback that calls ``register_api_callback`` or
        ``execute_callback`` on this handler would block.

        Returns:
            RET_OK if the callback ran without raising; RET_ERROR if the
            action is unregistered or the callback raised an ``Exception``.
        """
        with self.lock:
            try:
                ttp_logger.LOGGER.info(f"do action {action}, arg={args!r}")
                if action not in self._action_func_map:
                    ttp_logger.LOGGER.error(f"action {action} unregistered")
                    return RET_ERROR
                func = self._action_func_map[action]
                if args is None:
                    func()
                else:
                    func(*args)
                return RET_OK
            except Exception as e:
                # A failed callback is an error, so it is logged at error level
                # like the other failure paths rather than at info level.
                ttp_logger.LOGGER.error(f"do action {action} err, err={e!s}, arg={args}")
                return RET_ERROR

    def is_all_action_registered(self):
        """Return RET_OK if every supported action has a callback, else RET_ERROR.

        The check compares the number of registered actions with the number of
        supported actions; registration only accepts supported names.
        """
        count = len(self._action_func_map)
        all_registered = count == len(self.support_actions())
        ttp_logger.LOGGER.info(f" {count} actions api has been registered..., all registered:{all_registered}")
        return RET_OK if all_registered else RET_ERROR
# Module-level handler instance used by the wrapper functions in this module.
mindx_handler: MindxEngineHandler = MindxEngineHandler()
def tft_register_mindx_callback(action: str, func: Callable):
    """Register ``func`` for ``action`` on the module-level handler.

    Delegates to ``mindx_handler.register_api_callback`` and returns RET_OK
    only when that call returned RET_OK; any other result becomes RET_ERROR.
    """
    status = mindx_handler.register_api_callback(action, func)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def tft_notify_controller_stop_train(fault_ranks: dict, stop_type: str = "stop", timeout: Optional[int] = None):
    """Forward a stop or pause request to the ``ttp_c2python_api`` bindings.

    Args:
        fault_ranks: passed to ``mindx_stop_train_callback`` when
            ``stop_type`` is "stop"; not used for "pause".
        stop_type: "stop" or "pause".
        timeout: passed to ``mindx_pause_train_callback`` when ``stop_type``
            is "pause"; not used for "stop".

    Returns:
        RET_OK if the selected binding returned RET_OK; RET_ERROR otherwise,
        including when ``stop_type`` is neither "stop" nor "pause".
    """
    if stop_type == "stop":
        ret = ttp_c2python_api.mindx_stop_train_callback(fault_ranks)
    elif stop_type == "pause":
        ret = ttp_c2python_api.mindx_pause_train_callback(timeout)
    else:
        # Log the bad value so the failure is not silent for the caller.
        ttp_logger.LOGGER.error(f"unsupported stop_type[{stop_type}], expected 'stop' or 'pause'")
        return RET_ERROR
    return RET_OK if ret == RET_OK else RET_ERROR
def tft_notify_controller_on_global_rank(fault_ranks: dict, time: int = 1):
    """Forward ``fault_ranks`` and ``time`` to ``mindx_notify_fault_callback``.

    Returns RET_OK when the binding returned RET_OK, RET_ERROR otherwise.
    """
    status = ttp_c2python_api.mindx_notify_fault_callback(fault_ranks, time)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def tft_notify_controller_prepare_action(action: str, fault_ranks: dict = None):
    """Forward ``action`` and ``fault_ranks`` to ``mindx_prepare_action_callback``.

    A falsy ``fault_ranks`` (None or empty) is replaced by a new empty dict
    before the call. Returns RET_OK when the binding returned RET_OK,
    RET_ERROR otherwise.
    """
    ranks = fault_ranks if fault_ranks else {}
    status = ttp_c2python_api.mindx_prepare_action_callback(action, ranks)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def tft_notify_controller_change_strategy(strategy: str, params: str = ""):
    """Forward ``strategy`` and ``params`` to ``mindx_change_strategy_callback``.

    Returns RET_OK when the binding returned RET_OK, RET_ERROR otherwise.
    """
    status = ttp_c2python_api.mindx_change_strategy_callback(strategy, params)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def tft_notify_controller_dump():
    """Call ``mindx_notify_dump_callback`` and normalise its result.

    Returns RET_OK when the binding returned RET_OK, RET_ERROR otherwise.
    """
    status = ttp_c2python_api.mindx_notify_dump_callback()
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def report_fault_ranks_callback(error_rank_dict, error_rank_code):
    """Run the callback registered for 'report_fault_ranks'.

    The two arguments are passed positionally, in order, to the registered
    callable. Returns RET_OK when the handler returned RET_OK, RET_ERROR
    otherwise.
    """
    payload = (error_rank_dict, error_rank_code)
    status = mindx_handler.execute_callback('report_fault_ranks', payload)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def report_stop_complete_callback(code, msg, error_rank_dict):
    """Run the callback registered for 'report_stop_complete'.

    The three arguments are passed positionally, in order, to the registered
    callable. Returns RET_OK when the handler returned RET_OK, RET_ERROR
    otherwise.
    """
    payload = (code, msg, error_rank_dict)
    status = mindx_handler.execute_callback('report_stop_complete', payload)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def report_strategies_callback(error_rank_dict, strategy_list):
    """Run the callback registered for 'report_strategies'.

    The two arguments are passed positionally, in order, to the registered
    callable. Returns RET_OK when the handler returned RET_OK, RET_ERROR
    otherwise.
    """
    payload = (error_rank_dict, strategy_list)
    status = mindx_handler.execute_callback('report_strategies', payload)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def report_result_callback(code, msg, error_rank_dict, curr_strategy):
    """Run the callback registered for 'report_result'.

    The four arguments are passed positionally, in order, to the registered
    callable. Returns RET_OK when the handler returned RET_OK, RET_ERROR
    otherwise.
    """
    payload = (code, msg, error_rank_dict, curr_strategy)
    status = mindx_handler.execute_callback('report_result', payload)
    if status == RET_OK:
        return RET_OK
    return RET_ERROR
def register_check_callback():
    """Report whether the module-level handler has all actions registered.

    Delegates to ``mindx_handler.is_all_action_registered`` and returns RET_OK
    when that call returned RET_OK, RET_ERROR otherwise.
    """
    status = mindx_handler.is_all_action_registered()
    if status == RET_OK:
        return RET_OK
    return RET_ERROR