"""
自动并行配置搜索主要入口
"""
import os
import torch
import logging
import multiprocessing
from mindspeed.auto_settings.config.post_info import PostInfo
from mindspeed.auto_settings.config.model_config import set_model_config, get_model_config, \
update_model_config
from mindspeed.auto_settings.config.system_config import set_system_config, get_system_config
from mindspeed.auto_settings.profile.profiler import Profiler
from mindspeed.auto_settings.module.searcher import WhiteSearcher, BlackSearcher, MixedSearcher
from mindspeed.auto_settings.search_space import SearchSpace
from mindspeed.auto_settings.utils.file_utils import restricted_read
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.mindspeed_adaptor.mindspeed_executor import ExecutorFlag
from mindspeed.auto_settings.module.search.auto_tuning_disttrain import ModelSpaces, SingleModel
def resolve_search_cpu_num(raw_value):
    """Translate the raw ``AUTO_SETTING_SEARCH_CPU_NUM`` setting into a CPU count.

    Args:
        raw_value: The environment variable's string value, or ``None`` when
            the variable is not set.

    Returns:
        int: ``multiprocessing.cpu_count()`` when ``raw_value`` is ``None``;
        the parsed integer when it is greater than 1; otherwise the
        fallback value 4.
    """
    fallback = 4
    if raw_value is None:
        return multiprocessing.cpu_count()
    try:
        requested = int(raw_value)
    except (TypeError, ValueError):
        # A malformed value must not abort the import of this module; it is
        # treated like any other unusable value (<= 1) and replaced by the
        # fallback, but reported so the misconfiguration is visible.
        logging.getLogger(__name__).warning(
            "Ignoring non-integer AUTO_SETTING_SEARCH_CPU_NUM=%r, using %d instead",
            raw_value, fallback,
        )
        return fallback
    return requested if requested > 1 else fallback


# CPU count passed as ``cpu_num`` to the space search; resolved once at import.
AUTO_SETTING_SEARCH_CPU_NUM = resolve_search_cpu_num(os.environ.get('AUTO_SETTING_SEARCH_CPU_NUM'))
class AutoSettings(object):
    """Driver object for the automatic parallel-configuration search.

    Holds a ``SearchSpace``, a ``Profiler`` and a logger, and exposes
    ``auto_setting_fun`` as the entry point that initializes the global
    configuration and then runs the search.
    """

    def __init__(self):
        self.search_spaces = SearchSpace()
        self.profiler = Profiler()
        self.logger = get_logger("AutoSettings")

    @staticmethod
    def _is_white_disttrain(args):
        """Return True when ``args`` requests dist-train tuning with the "white" searcher.

        ``dist_train`` is optional on ``args`` (a missing attribute counts as
        False); ``auto_settings_type`` is only read when ``dist_train`` is truthy.
        """
        return bool(getattr(args, "dist_train", False)) and args.auto_settings_type == "white"

    def init(self, args):
        """
        Initialize the configs, the hardware information and the global process group.
        """
        self._init_configs(args)
        self._init_hardware(args)
        self._init_global_group()

    def _init_hardware(self, args):
        """
        Run the profiler in PARSE_ARGS mode and load the post info it leaves behind.

        The file ``PostInfo.FILENAME`` is read from the system config's work
        directory; its content is loaded into the system config and its
        ``model_config`` is used to update the model config.
        ``args`` is accepted for interface symmetry but is not used here.
        """
        self.profiler.run(PostInfo.FILENAME, None, ExecutorFlag.PARSE_ARGS)
        work_dir = str(get_system_config().work_dir)
        post_info = restricted_read(os.path.join(work_dir, PostInfo.FILENAME))
        get_system_config().load_settings(post_info)
        update_model_config(post_info.model_config)

    def _init_configs(self, args):
        """
        Populate the global system config and model config from ``args``.
        """
        set_system_config(args)
        set_model_config(args)

    def _init_global_group(self):
        """
        Initialize the default torch.distributed process group on the GLOO backend.

        The group uses the system config's ``node_rank`` as rank and
        ``nnodes`` as world size.
        """
        sys_config = get_system_config()
        torch.distributed.init_process_group(
            backend=torch.distributed.Backend.GLOO,
            rank=sys_config.node_rank,
            world_size=sys_config.nnodes
        )

    def _get_searcher(self, search_type):
        """
        Return the searcher matching the user-selected ``search_type``.

        "white" and "black" map to their dedicated searchers. "mixed" yields a
        ``MixedSearcher`` with its white top-k set to 5. Any other value falls
        back to a ``MixedSearcher`` left at its defaults.
        """
        if search_type == "white":
            return WhiteSearcher()
        if search_type == "black":
            return BlackSearcher()
        if search_type == "mixed":
            searcher = MixedSearcher()
            searcher.set_white_topk(5)
            return searcher
        return MixedSearcher()

    def search(self, args):
        """
        Build the search space from the current model config and run the search.

        Args:
            args: Parsed arguments; ``args.auto_settings_type`` selects the searcher.

        Returns:
            Whatever the searcher's ``search`` returns for ``topk=3``.
        """
        self.logger.info("model config is that:\n %s", get_model_config())
        searcher = self._get_searcher(args.auto_settings_type)
        self.search_spaces = SearchSpace(get_model_config())
        search_configs = self.search_spaces.build_search_spaces()
        return searcher.search(configs=search_configs, topk=3)

    def auto_setting_fun(self, args):
        """
        Entry function of the auto-settings flow.

        Returns:
            The final configs produced by ``search`` when the regular search
            path runs on node rank 0; ``None`` on every early-return path
            (white dist-train tuning, or a node whose rank is not 0).
        """
        self.logger.info('Auto settings start')
        self.logger.info('Auto settings start init')
        if self._is_white_disttrain(args):
            # Dist-train tuning only needs the configs, not the full init().
            self._init_configs(args)
            self.auto_tuning_disttrain(args)
            return

        self.init(args)
        # NOTE(review): same predicate as the guard above, so this branch is
        # only reachable if init() changes ``args`` -- confirm and remove if not.
        if self._is_white_disttrain(args):
            self.logger.info('Auto settings for disttrain of white')
            return

        if get_system_config().node_rank != 0:
            self.logger.info('Auto settings run on slaves')
            self.profiler.run_on_slaves(args)
            return

        self.logger.info('<==========Auto settings search config start==========>')
        final_configs = self.search(args)
        self.logger.info("<==========Final config generated==========>")
        self.logger.info("The recommended configurations are:")
        for i, final_cfg in enumerate(final_configs):
            if final_cfg:
                self.logger.info("<==========Top #%s config==========>", i)
                if self.logger.getEffectiveLevel() == logging.DEBUG:
                    self.logger.debug("\n %s", final_cfg)
                else:
                    self.logger.info("\n %s", final_cfg)
        self.logger.info("<==========Auto settings search completed==========>")
        return final_configs

    def auto_tuning_disttrain(self, args):
        """
        Model every entry of the model space and run the space search.

        When ``args.dist_train`` is truthy the model type is "vit+gpt" and the
        multimodal surrogate search is used; otherwise a single model is built
        from the source settings and the plain ``SpaceSearch`` is used.
        A node whose rank is not 0 waits on the executor for the first model
        it processes and then returns.
        """
        model_spaces = ModelSpaces()
        logger = get_logger("main")
        if getattr(args, "dist_train", False):
            model_spaces.model_type = "vit+gpt"
            try:
                from mindspeed.auto_settings.module.search.multimodal_search_surrogate import \
                    MultimodalSpaceSearchSurrogate
                model_spaces.space_search_model = MultimodalSpaceSearchSurrogate()
                model_spaces.dist_train_config()
            except Exception:
                # Still best-effort (execution continues), but the failure is
                # logged with its traceback instead of a bare INFO message.
                logger.exception("Failed to set up the dist-train search space")
        else:
            try:
                from mindspeed.auto_settings.module.search.search_engine import SpaceSearch
                model_spaces.space_search_model = SpaceSearch()
                single_model = SingleModel(model_spaces.time_recorder)
                model_spaces.model_list = [single_model]
                single_model.model_settings = model_spaces.sources_setting
                single_model.model_config.sub_work_dir = model_spaces.sources_setting.work_dir
            except Exception:
                # Still best-effort (execution continues), but the failure is
                # logged with its traceback instead of a bare INFO message.
                logger.exception("Failed to set up the single-model search space")

        for model in model_spaces.model_list:
            model.parse_args(model_spaces.model_type)
            if model.model_settings.node_rank != 0:
                model.model_settings.executor.wait(
                    model_config=model.model_config,
                    gloo_group=model.model_settings.gloo_group
                )
                return
            model._memory_model()
            model._performance_model()
            # Drop the group handle before the settings reach the search;
            # presumably so they can be shared with worker processes -- TODO confirm.
            model.model_settings.gloo_group = None
        model_spaces.space_search_model._space_search(
            models=model_spaces.model_list,
            cpu_num=AUTO_SETTING_SEARCH_CPU_NUM
        )