from typing import Deque, List, Optional, Tuple
from collections import deque
from copy import deepcopy
import os
import sys
import traceback as tb
from mindspeed.auto_settings.utils.logger import get_logger, change_stream_handler
from mindspeed.auto_settings.module.memory.memory_modeling import MemoryModeling
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.search.stage_1_prune import stage_1_discrete_search_space_prune
from mindspeed.auto_settings.config.model_config import ModelConfig
from mindspeed.auto_settings.config.post_info import PostInfo
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.utils.utils import get_prof_dir
from mindspeed.auto_settings.utils.file_utils import restricted_read
from functools import partial
from multiprocessing import Pool, JoinableQueue, Process, Event, queues, Manager
from io import StringIO
import time
from mindspeed.auto_settings.module.model_performance import ModelPerformance
from mindspeed.auto_settings.auto_settings import SingleModel
class SpaceSearch:
    """Rank stage-1 pruned parallel configurations by estimated performance.

    For every candidate configuration the memory model is queried first;
    candidates that fit the device memory cap are then scored by the
    performance model, adjusted by ``full_recompute_solver`` and kept in a
    three-entry ranking (a lower performance value is treated as better).
    """

    def __init__(self):
        self._logger = get_logger("search")
        # Not populated by any method of this class; it stays ``None`` unless
        # it is assigned from outside.
        self.perf_cfg_map: Optional[Deque[Tuple[float, Optional[SearchConfig]]]] = None

    def _space_search(self, models: "List[SingleModel]", cpu_num: int):
        """Search the first model and log the recommended configurations.

        Args:
            models: Models to search; only ``models[0]`` is used.
            cpu_num: Not used by this method; kept so existing callers
                keep working.
        """
        logger = get_logger("search_logger")
        model = models[0]

        search_cfg_start_time = time.time()
        best_cfgs = self.search_demo(model=model)
        search_cfg_end_time = time.time()
        logger.info(">>>>>> Search_cfg cost time: %s ms",
                    str((search_cfg_end_time - search_cfg_start_time) * 1000))

        logger.info("<==========Final config generated==========>")
        logger.info("The recommended configs are:")
        for i, final_cfg in enumerate(best_cfgs):
            # Ranking slots that were never filled are ``None`` and are skipped.
            if final_cfg:
                logger.info("<==========Top #%s config==========>", str(i))
                logger.info("\n %s", str(final_cfg))
        logger.info("<==========Launch training==========>")

    def search_demo(
            self,
            model: "SingleModel",
            re_profiling_flag=True
    ) -> "List[Optional[SearchConfig]]":
        """Evaluate every stage-1 candidate and return the three best configs.

        Args:
            model: Object providing ``memory_model``, ``model_performance``,
                ``model_settings`` and ``model_config``.
            re_profiling_flag: Forwarded unchanged to the performance model.

        Returns:
            A list of exactly three entries ordered from best (lowest
            estimated performance value) to worst. Slots for which no
            candidate was found are ``None``.
        """
        mem_model = model.memory_model
        perf_model = model.model_performance
        model_config = model.model_config
        # No trailing comma here: the performance model must receive the
        # directory itself, not a 1-tuple wrapping it.
        working_dir = model.model_settings.work_dir

        system_config = get_system_config()
        device_mem_cap = system_config.memory_cap
        self._logger.info(f"Search: total_device_num: {system_config.search_world_size}")
        self._logger.info(f"Search: device_mem_cap: {device_mem_cap}")

        # Fixed-size ranking, pre-filled with (inf, None) placeholders so any
        # real candidate beats an empty slot.
        best_perf_cfg_map: Deque[Tuple[float, Optional[SearchConfig]]] = deque([(float("inf"), None)] * 3, 3)

        stage_1_valid_ptd_configs = stage_1_discrete_search_space_prune(model_config)
        self._logger.info(f"Stage [1] pruned result: number of valid PTD configurations [{len(stage_1_valid_ptd_configs)}]")
        for cfg in stage_1_valid_ptd_configs:
            self._logger.info(f"Stage [1] pruned config: TP=[{cfg.tp}] PP=[{cfg.pp}] LAYERS_PER_VPP=[{cfg.layers_per_vpp}] DP=[{cfg.dp}] CP=[{cfg.cp}] EP=[{cfg.ep}] ZeRO=[{cfg.zero1}]")

        # One-element list handed to every performance() call.
        # NOTE(review): presumably a mutable profiling counter - confirm.
        profile_count = [0]

        for cfg in stage_1_valid_ptd_configs:
            self._logger.info("====================")
            self._logger.info(f"Looking at:\n\n{cfg}")

            recompute_mem, peak_stage_mem, optimizer_peak = mem_model.estimate(cfg)
            if max(peak_stage_mem, optimizer_peak) > device_mem_cap:
                self._logger.info("OOM found, next!")
                continue

            try:
                perf, _uncovered_prof, use_mc2, fw_performance = perf_model.performance(
                    cfg, working_dir, profile_count, re_profiling_flag
                )
            except Exception as err:
                self._logger.warning(f"Search: ERROR during perf_modeling_calculation: {type(err).__name__}")
                tb.print_exc()
                # Without a performance estimate this candidate cannot be
                # ranked; falling through would use unbound or stale values.
                continue

            self._logger.debug(f"before recompute, perf = {perf} and memory = {peak_stage_mem}")
            self._logger.debug(f"success enter recompute_solver and tp = {cfg.tensor_model_parallel_size} "
                               f"pp = {cfg.pipeline_model_parallel_size} "
                               f"layers_per_vpp={cfg.num_layers_per_virtual_pipeline_stage} "
                               f"dp = {cfg.data_parallel_size} cp = {cfg.context_parallel_size} "
                               f"ep = {cfg.expert_model_parallel_size} zero = {cfg.use_distributed_optimizer}")
            need_recompute, new_perf, add_mem, recompute_layer = self.full_recompute_solver(
                device_mem_cap - peak_stage_mem, model_config, perf, cfg, recompute_mem, fw_performance
            )
            new_memory = add_mem + peak_stage_mem
            self._logger.debug(f"after recompute, perf = {new_perf} and need_recompute = {need_recompute}")
            self._logger.debug(f"cur mem_estimated = {new_memory}, recompute_layer = {recompute_layer}")

            for i, (ranked_perf, _) in enumerate(best_perf_cfg_map):
                if new_perf < ranked_perf:
                    cfg.performance = new_perf
                    cfg.memory = new_memory
                    cfg.recompute_num_layers = recompute_layer
                    cfg.use_ascend_mc2 = use_mc2 if cfg.tensor_model_parallel_size > 1 else False
                    self._logger.info(f"Search: SUCCESSFUL Better #{i} Config Found.")
                    self._logger.debug(f"Performance Estimation: {new_perf}.")
                    # The deque is bounded and always full, and insert() on a
                    # full bounded deque raises IndexError, so the worst
                    # entry is dropped first.
                    best_perf_cfg_map.pop()
                    best_perf_cfg_map.insert(i, (new_perf, deepcopy(cfg)))
                    break
            else:
                self._logger.info("Sub-optimal performance, next!")

        return [cfg for _, cfg in best_perf_cfg_map]

    def full_recompute_solver(self, oom_cap, model_cfg: "ModelConfig", perf, search_config, fw_memory, fw_performance):
        """Adjust a performance estimate for the memory headroom available.

        Args:
            oom_cap: Memory headroom; the caller passes the device memory cap
                minus the estimated peak stage memory.
            model_cfg: Provides ``num_layers`` and ``gbs``.
            perf: Performance estimate to adjust.
            search_config: Provides ``pp``, ``dp``, ``mbs``, ``layers_per_vpp``
                and, when ``layers_per_vpp`` is set, ``num_layers``.
            fw_memory: Used as the per-layer memory amount.
            fw_performance: Used as the per-layer, per-micro-batch time term.

        Returns:
            A tuple ``(need_recompute, new_perf, add_mem, recompute_layers)``.
            The caller adds ``add_mem`` to the peak stage memory and stores
            ``recompute_layers`` as the config's ``recompute_num_layers``.
        """
        if search_config.layers_per_vpp:
            # NOTE(review): this reads num_layers from search_config while the
            # rest of the method reads it from model_cfg - confirm they agree.
            num_model_chunks = search_config.num_layers // search_config.layers_per_vpp // search_config.pp
            layers_per_vpp = search_config.layers_per_vpp
        else:
            num_model_chunks = 1
            layers_per_vpp = model_cfg.num_layers // search_config.pp

        warmup_micro_batches, total_num_micro_batches = self.get_num_warmup_micro_batches(
            num_model_chunks, search_config, model_cfg
        )
        num_layers = model_cfg.num_layers // search_config.pp
        memory_per_layer = fw_memory

        # Case 1: everything fits in the headroom, so nothing is recomputed.
        max_release_mem = warmup_micro_batches * layers_per_vpp * memory_per_layer - memory_per_layer
        if max_release_mem <= oom_cap:
            return False, perf - total_num_micro_batches * num_layers * fw_performance, max_release_mem, 0

        # Case 2: only with layers_per_vpp set, a smaller baseline fits and the
        # remaining headroom is spent in whole-layer steps.
        if search_config.layers_per_vpp:
            max_release_mem = (num_model_chunks - 1) * search_config.pp * layers_per_vpp * memory_per_layer
            if max_release_mem <= oom_cap:
                per_layer_step = (2 * search_config.pp - 1) * memory_per_layer
                layer_calculate = (oom_cap - max_release_mem) // per_layer_step
                release_mem = per_layer_step * layer_calculate + max_release_mem - memory_per_layer
                time_cost = (num_layers - layers_per_vpp + layer_calculate) * total_num_micro_batches * fw_performance
                return True, perf - time_cost, release_mem, layers_per_vpp - layer_calculate

        # Case 3: shared fallback for both the vpp and non-vpp paths (they were
        # two identical copies before).
        layer_calculate = oom_cap // (memory_per_layer * search_config.pp)
        release_mem = layer_calculate * memory_per_layer * search_config.pp
        if layer_calculate < num_layers:
            release_mem -= memory_per_layer
        time_cost = total_num_micro_batches * layer_calculate * fw_performance
        return True, perf - time_cost, release_mem, num_layers - layer_calculate

    def get_num_warmup_micro_batches(self, num_model_chunks, search_config, model_cfg):
        """Return ``(num_warmup_micro_batches, num_microbatches)``.

        ``num_microbatches`` is ``gbs // (mbs * dp)``. The warm-up count is
        computed with the pipeline rank fixed at 0, incremented by one and
        capped at ``num_microbatches * num_model_chunks``. When ``pp <= 1``
        the warm-up count is always 1.

        Args:
            num_model_chunks: Number of model chunks per pipeline stage.
            search_config: Provides ``pp``, ``dp`` and ``mbs``.
            model_cfg: Provides ``gbs``.
        """
        pipeline_parallel_size = search_config.pp
        data_parallel_size = search_config.dp
        num_microbatches = model_cfg.gbs // (search_config.mbs * data_parallel_size)

        if pipeline_parallel_size <= 1:
            return 1, num_microbatches

        # The rank is hard-coded to 0 in this method.
        pipeline_parallel_rank = 0
        total_num_micro_batches = num_microbatches * num_model_chunks
        if num_model_chunks == 1:
            num_warmup_micro_batches = pipeline_parallel_size - pipeline_parallel_rank - 1
        else:
            num_warmup_micro_batches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_micro_batches += (num_model_chunks - 1) * pipeline_parallel_size
        num_warmup_micro_batches += 1
        num_warmup_micro_batches = min(num_warmup_micro_batches, total_num_micro_batches)
        return num_warmup_micro_batches, num_microbatches