from typing import List
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication.comm_perf_linear_model_factory import (
CommPerfLinearModelFactory,
)
from mindspeed.auto_settings.module.communication.comm_perf_predictor import CommPerfPredictor, SimpleParallelCfg
from mindspeed.auto_settings.module.communication.communication_profile import PpProfileTimeInfo
class DebugPpComm:
    """Plain record describing one PP communication sample, kept for debugging.

    Every attribute starts at a neutral default and is meant to be filled in
    by whoever creates the record.
    """

    def __init__(self):
        # Which configuration the sample belongs to.
        self.cfg_no = 0
        self.cfg = SearchConfig()
        # Name of the model type that consumed the sample; None until set.
        self.model_type = None
        # Model input feature and the time recorded alongside it.
        self.comm_x = 0
        self.total_time = 0
class PpCommPerfPredictor(CommPerfPredictor):
    """Predictor for pipeline-parallel (PP) communication time.

    Profiled samples are fed into models obtained from
    ``CommPerfLinearModelFactory`` under the module name ``"pp"``; the same
    factory is queried again at prediction time.

    NOTE(review): ``self.logger``, ``self.max_hccs_rank_num`` and
    ``self.debug_info_list`` are not assigned in this class; presumably the
    base class provides them -- confirm.
    """

    def __init__(self, hard_info):
        """Forward ``hard_info`` unchanged to the base-class initialiser."""
        super().__init__(hard_info)

    def get_communication_info_from_profile(self, pp_profile_time_info, hcom_info_tage_id, pp):
        """Copy ``hcom_info_tage_id.min_pp_time`` into ``pp_profile_time_info.each_pp_time``.

        ``pp`` is accepted for interface compatibility but is not used.
        """
        pp_profile_time_info.each_pp_time = hcom_info_tage_id.min_pp_time

    def receive_samples_from_profiling(
        self, config_no, model_config: SearchConfig, pp_profile_time_info: PpProfileTimeInfo
    ):
        """Register one profiled sample with the matching "pp" model.

        Nothing is registered when ``model_config.pp <= 1``.

        Args:
            config_no: Identifier of the profiled configuration.
            model_config: Parallel configuration the profile was taken with.
            pp_profile_time_info: Profile result; ``each_pp_time`` is used as
                the sample's measured time.
        """
        tp = model_config.tp
        cp = model_config.cp
        pp = model_config.pp
        dp = model_config.dp
        # A falsy layers_per_vpp (None / 0) is treated as 1.
        layers_per_vpp = model_config.layers_per_vpp if model_config.layers_per_vpp else 1
        comm_x = 1 / (layers_per_vpp * tp * cp)
        comm_time = pp_profile_time_info.each_pp_time
        cfg = SimpleParallelCfg(config_no, tp, cp, dp, '', pp, layers_per_vpp)
        if pp <= 1:
            return

        # The model is keyed by a single rank count used as both bounds.
        rank_num = pp * dp * cp * tp
        pp_time_model = CommPerfLinearModelFactory.get_or_create_model(
            "pp",
            min_rank_num=rank_num,
            max_rank_num=rank_num,
            max_hccs_dev_num=self.max_hccs_rank_num,
        )
        pp_time_model.add_sample(comm_x, comm_time, cfg)

        # Keep a record of what was fed to the model.
        debug_info = DebugPpComm()
        debug_info.comm_x = comm_x
        debug_info.total_time = comm_time
        debug_info.cfg = model_config
        debug_info.cfg_no = config_no
        debug_info.model_type = str(type(pp_time_model))
        self.debug_info_list.append(debug_info)

    def fit(self):
        """Call ``fit()`` on every truthy model registered under "pp"."""
        for model in CommPerfLinearModelFactory.get_models_by_module_name("pp"):
            if model:
                model.fit()

    def debug(self, config_list: List[SearchConfig]):
        """Log a table of the configurations with ``pp > 1``, then each "pp" model.

        Args:
            config_list: Configurations to list; the row number is the index
                in this list.
        """
        self.logger.debug("****************** PP modeling ***********************")
        tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}"
        # The template has exactly seven fields, so seven values are passed.
        self.logger.debug(tplt.format("ConfigNo", "tp", "dp", "pp", "vp", "cp", "ep"))
        for i, model_config in enumerate(config_list):
            if model_config.pp <= 1:
                continue
            cur_row = tplt.format(
                i,
                model_config.tp,
                model_config.dp,
                model_config.pp,
                str(model_config.layers_per_vpp),
                model_config.cp,
                model_config.ep,
            )
            self.logger.debug(cur_row)

        for model in CommPerfLinearModelFactory.get_models_by_module_name("pp"):
            if model:
                model.debug(f"pp_{model.protocol_name}")

    def predict(self, search_cfg: SearchConfig):
        """Predict the PP communication time for ``search_cfg``.

        Returns ``0.0`` when the pipeline-parallel size is at most 1; in that
        case the model factory is not consulted, so no model gets registered
        for a configuration that has no PP communication.

        Otherwise the model's prediction for ``1 / (vp * tp * cp)`` is doubled
        and then multiplied by ``(pp * vp - 1) * 2``.
        """
        pp = search_cfg.pipeline_model_parallel_size
        if pp <= 1:
            return 0.0

        tp = search_cfg.tensor_model_parallel_size
        dp = search_cfg.data_parallel_size
        cp = search_cfg.context_parallel_size
        layers_per_stage = search_cfg.num_layers_per_virtual_pipeline_stage
        # Without virtual pipeline stages configured, vp falls back to 1.
        vp = search_cfg.num_layers // (pp * layers_per_stage) if layers_per_stage else 1
        # NOTE(review): samples are registered with 1 / (layers_per_vpp * tp * cp)
        # while this feature uses the derived vp -- confirm both describe the
        # same quantity.
        comm_x = 1 / (vp * tp * cp)

        # Same single-rank-count key as used when samples are registered.
        rank_num = pp * dp * cp * tp
        pp_time_model = CommPerfLinearModelFactory.get_or_create_model(
            "pp",
            min_rank_num=rank_num,
            max_rank_num=rank_num,
            max_hccs_dev_num=self.max_hccs_rank_num,
        )
        each_pp_time = pp_time_model.predict(comm_x) * 2
        return each_pp_time * (pp * vp - 1) * 2