import os
from enum import IntEnum
import torch.distributed as dist
from typing import Any, Dict, List, Optional
from mindspeed.auto_settings.config.search_config import SearchConfig, ExecutorFlag
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
from mindspeed.auto_settings.profile.argv import BaseArgv, FilterArgv, SearchConfigArgv
from mindspeed.auto_settings.profile.runner import Runner
from mindspeed.auto_settings.utils.file_utils import restricted_write
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.utils.singleton import Singleton
from mindspeed.auto_settings.utils.utils import check_file_exists, get_prof_dir
class Profiler(metaclass=Singleton):
    """Launch runs on every node for argument parsing, model parsing or profiling.

    The master node drives work through :meth:`run` / :meth:`profile`; slave nodes
    sit in :meth:`run_on_slaves` and execute each request the master broadcasts.
    Every request is executed by a ``Runner`` with an environment in which at most
    one of the mode variables below is set to ``ENABLED_ENV_MARKER`` (how the
    launched process interprets them is not visible in this file).
    """

    # Mode environment variables handed to the launched process.
    PARSE_ARGS_ENV = "OOTB_OPTIMIZER_PARSE_ARGS"
    PARSE_MODEL_ENV = "OOTB_OPTIMIZER_PARSE_MODEL"
    PROFILING_ENV = "OOTB_OPTIMIZER_PROFILING"
    PROFILING_ENV_BLACK = "OOTB_OPTIMIZER_PROFILING_BLACK"
    ENABLED_ENV_MARKER = "TRUE"

    # Consecutive monitored-barrier timeouts a slave tolerates before giving up.
    MAX_CONSECUTIVE_WAIT_TIMEOUTS = 10

    def __init__(self):
        self._logger = get_logger("Profiler")

    def init(self):
        """Create a fresh ``Runner``; called at the start of every run."""
        self.runner = Runner()

    def run(self, output_filename: str,
            cfg: Optional[SearchConfig] = None,
            flag: ExecutorFlag = ExecutorFlag.PROFILE):
        """Execute one request from the master node.

        ``PARSE_ARGS`` requests are executed locally and are not broadcast.
        Any other request is synchronized with all ranks through a monitored
        barrier. The triple ``(output_filename, cfg, flag)`` is then broadcast from
        rank 0 (the ``broadcast_object_list`` default source), so that slaves in
        :meth:`run_on_slaves` run the same request. After the local run, all ranks
        meet at a closing barrier.

        Args:
            output_filename: Output file name, joined onto the system ``work_dir``.
            cfg: Search configuration to run, or ``None``.
            flag: Selects which mode environment variable is enabled.

        Returns:
            The value returned by ``Runner.run`` on this node.
        """
        self.init()
        if flag == ExecutorFlag.PARSE_ARGS:
            # Not broadcast: slaves do not take part in argument parsing.
            return self._prepare(output_filename, cfg=cfg, flag=flag)

        dist.monitored_barrier(wait_all_ranks=True)
        dist.broadcast_object_list([output_filename, cfg, flag])
        return_code = self._prepare(output_filename, cfg=cfg, flag=flag)
        dist.barrier()
        return return_code

    def run_on_slaves(self, args):
        """Serve master requests on a slave node until the master goes away.

        In each round the slave waits at the monitored barrier and receives the
        ``(output_filename, cfg, flag)`` triple broadcast by :meth:`run`. It then
        runs the request via :meth:`_prepare` and meets the master at the closing
        barrier.

        Args:
            args: Unused; kept for interface compatibility.

        Raises:
            RuntimeError: After more than ``MAX_CONSECUTIVE_WAIT_TIMEOUTS``
                consecutive monitored-barrier timeouts, or for any other
                distributed error not recognized below.
        """
        self.init()
        # Timeouts since the last completed round. Only consecutive timeouts
        # count; previously the counter was never reset, so a long search could
        # abort even though the master was alive.
        timeouts = 0
        while True:
            try:
                self._logger.info(f"[#{timeouts}] Waiting for master.....")
                dist.monitored_barrier(wait_all_ranks=True)
                bcast_list: List[Any] = [None] * 3
                dist.broadcast_object_list(bcast_list)
                output_filename, cfg, flag = bcast_list
                self._prepare(output_filename, cfg=cfg, flag=flag)
                dist.barrier()
            except RuntimeError as e:
                message = str(e)
                if "successfully reached monitoredBarrier" in message:
                    # Treated as a timeout while waiting for the master.
                    timeouts += 1
                    if timeouts > self.MAX_CONSECUTIVE_WAIT_TIMEOUTS:
                        self._logger.critical("Wait timeout, shutting down.")
                        raise
                elif "Connection closed by peer" in message:
                    self._logger.info("Master shuts down, exiting.....")
                    return
                else:
                    raise
            else:
                # A completed round shows the master is alive.
                timeouts = 0

    def _prepare(
        self,
        output_filename: str,
        cfg: Optional[SearchConfig] = None,
        flag: ExecutorFlag = ExecutorFlag.PROFILE
    ) -> int:
        """Build the environment and argv for one request, persist ``cfg``, and launch it.

        Args:
            output_filename: Output file name, joined onto the system ``work_dir``.
            cfg: Search configuration to run; when truthy, it is also written to
                ``<work_dir>/at_<node_rank>.pkl``.
            flag: Selects which mode environment variable is enabled.

        Returns:
            The value returned by ``Runner.run``.
        """
        system_config = get_system_config()
        work_dir = system_config.work_dir
        save_path = os.path.join(work_dir, output_filename)
        modified_env = self._update_env(flag)
        modified_argv = self._update_argv(save_path, cfg)
        # Using exist_ok avoids a check-then-create race if several nodes share
        # work_dir; makedirs also creates any missing parent directories.
        os.makedirs(work_dir, exist_ok=True)
        if cfg:
            restricted_write(os.path.join(work_dir, f"at_{system_config.node_rank}.pkl"), cfg)
        return self.runner.run(modified_argv, modified_env)

    def _update_env(self, flag: ExecutorFlag) -> Dict[str, str]:
        """Return the runner's base environment with only ``flag``'s mode variable enabled.

        All four mode variables are removed first, so none can leak in from the
        base environment. A flag with no associated variable leaves all of them
        unset.
        """
        env = self.runner.get_base_env()
        flag_to_env = {
            ExecutorFlag.PARSE_ARGS: self.PARSE_ARGS_ENV,
            ExecutorFlag.PARSE_MODEL: self.PARSE_MODEL_ENV,
            ExecutorFlag.PROFILE: self.PROFILING_ENV,
            ExecutorFlag.PROFILE_BLACK: self.PROFILING_ENV_BLACK,
        }
        for env_name in flag_to_env.values():
            env.pop(env_name, None)
        enabled_env = flag_to_env.get(flag)
        if enabled_env is not None:
            env[enabled_env] = self.ENABLED_ENV_MARKER
        return env

    def _update_argv(
        self,
        save_path: str,
        cfg: Optional[SearchConfig]
    ) -> List[str]:
        """Build the launch arguments for one request.

        Starts from the runner's base argv and then applies ``BaseArgv`` (with
        ``save_path``), ``FilterArgv`` and ``SearchConfigArgv`` (with ``cfg``).
        Their return values are ignored, so they presumably edit ``argv`` in
        place.
        """
        argv = self.runner.get_base_argv()
        BaseArgv.base_argv(argv, save_path)
        FilterArgv.filter_argv(argv)
        SearchConfigArgv.update_argv(argv, config=cfg)
        return argv

    def profile(self, configs):
        """Run each configuration whose output file does not exist yet and collect results.

        Args:
            configs: Sized sequence of ``(config, file_name)`` pairs. Each
                ``config.profile_type`` is passed as the run flag.

        Returns:
            A list of ``[config, profiling_result]`` pairs, one for each
            configuration that was run with ``ExecutorFlag.PROFILE``. Skipped
            configurations and other profile types contribute nothing.
        """
        profile_results = []
        total = len(configs)
        self._logger.info("<==========Begin to profile==========>")
        # 1-based so the progress log reads 1/N .. N/N.
        for idx, (config, file_name) in enumerate(configs, start=1):
            if check_file_exists(file_name):
                continue
            self._logger.info('<==========the %s/%s loop==========>', str(idx), str(total))
            self._logger.info("profile_db_configs (tp, pp, dp, cp, ep, #layers, seq_len):")
            self.run(file_name, config, flag=config.profile_type)
            if config.profile_type == ExecutorFlag.PROFILE:
                file_path = os.path.join(get_system_config().work_dir, get_prof_dir(config))
                profiling_node_parse = GatherNodeProfiling(file_path)
                profiling_res = profiling_node_parse.fuse_node_pkl()
                profile_results.append([config, profiling_res])
        self._logger.info("<==========Finished profiling==========>")
        return profile_results