from typing import List
from collections import namedtuple
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication.comm_perf_linear_model_factory import (
CommPerfLinearModelFactory,
)
from mindspeed.auto_settings.module.communication.comm_perf_predictor import CommPerfPredictor, SimpleParallelCfg
from mindspeed.auto_settings.module.communication.communication_profile import EpProfileTimeInfo
class DebugEpComm:
def __init__(self):
self.comm_x = 0
self.hccs_x = 0
self.roce_x = 0
self.total_time = 0
self.cfg = SearchConfig()
self.model_type = None
self.cfg_no = 0
class EpCommPerfPredictor(CommPerfPredictor):
def __init__(self, hard_info):
super(EpCommPerfPredictor, self).__init__(hard_info)
self.is_ep_modeling = False
def get_communication_info_from_profile(self, ep_profile_time_info, hcom_info_tage_id):
ep_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
ep_profile_time_info.wait_comm_time += hcom_info_tage_id.wait_time_ms
ep_profile_time_info.min_time += hcom_info_tage_id.min_comm_time_ms
def receive_samples_from_profiling(
self, config_no, model_config: SearchConfig, ep_profile_time_info: EpProfileTimeInfo
):
ep = model_config.ep
if model_config.moe_tp_extend_ep:
ep = ep * model_config.tp
if not ep or ep <= 1:
return
if model_config.num_experts:
self.is_ep_modeling = True
tp = model_config.tp
cp = model_config.cp
pp = model_config.pp
s = model_config.seq_length / 1000
traffic = s * (ep - 1) * pp / ep / tp / cp
bandwidth_910b = (ep - 1)
min_domain = tp
bandwidth = self.hard_info.calbandwidth(bandwidth_910b, min_domain)
comm_x = traffic / bandwidth
K = ep * tp / self.max_hccs_rank_num
hccs_x = s * K * pp / ep / tp / cp
roce_x = s * (K - 1) / K * pp / ep / tp / cp
comm_time = ep_profile_time_info.min_time
cfg = SimpleParallelCfg(config_no, tp, cp, '', ep, pp, '')
iv_list = [comm_x, hccs_x, roce_x, comm_time, cfg]
max_domain = ep * tp
if model_config.moe_tp_extend_ep:
max_domain = ep
min_domain = 0
ep_time_model = CommPerfLinearModelFactory.get_or_create_model(
"ep",
min_rank_num=min_domain,
max_rank_num=max_domain,
max_hccs_dev_num=self.max_hccs_rank_num,
)
ep_time_model.add_sample(*iv_list)
debug_info = DebugEpComm()
debug_info.comm_x = comm_x
debug_info.hccs_x = hccs_x
debug_info.roce_x = roce_x
debug_info.total_time = comm_time
debug_info.cfg = model_config
debug_info.cfg_no = config_no
debug_info.model_type = str(type(ep_time_model))
self.debug_info_list.append(debug_info)
def fit(self):
if self.is_ep_modeling:
for model in CommPerfLinearModelFactory.get_models_by_module_name("ep"):
if model:
model.fit()
def debug(self, config_list: List[SearchConfig]):
if not self.is_ep_modeling:
return
self.logger.debug(f"****************** EP modeling ***********************")
if "hccs" in CommPerfLinearModelFactory._instance_table["ep"].keys():
self.logger.debug(f"HCCS")
tplt = "{0:<8}\t{1:<8}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}\t{7:<1}"
self.logger.debug(tplt.format('x', 'total_time',
'No', 'tp', 'dp', 'pp', 'cp', 'ep', chr(12288)))
for debug_info in self.debug_info_list:
if "HCCS" in debug_info.model_type:
self.logger.debug(tplt.format(
round(debug_info.comm_x, 2),
round(debug_info.total_time, 3),
debug_info.cfg_no, debug_info.cfg.tp, debug_info.cfg.dp,
debug_info.cfg.pp, debug_info.cfg.cp, debug_info.cfg.ep,
chr(12288)))
self.logger.debug(f"-----------")
tplt = "{0:<9}\t{1:<9}"
self.logger.debug(tplt.format('w', 'b', chr(12288)))
attn_ag_model = CommPerfLinearModelFactory._instance_table["ep"]["hccs"]
attn_rs_model = CommPerfLinearModelFactory._instance_table["ep"]["hccs"]
self.logger.debug(tplt.format(round(attn_ag_model.w, 3), round(attn_ag_model.b, 3),
round(attn_rs_model.w, 3), round(attn_rs_model.b, 3),
chr(12288)))
self.logger.debug(f"----------------------")
if "roce" in CommPerfLinearModelFactory._instance_table["ep"].keys():
self.logger.debug(f"ROCE")
tplt = "{0:<8}\t{1:<8}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}\t{7:<1}"
self.logger.debug(tplt.format('x', 'total_time',
'No', 'tp', 'dp', 'pp', 'cp', 'ep', chr(12288)))
for debug_info in self.debug_info_list:
if "ROCE" in debug_info.model_type:
self.logger.debug(tplt.format(
round(debug_info.comm_x, 2),
round(debug_info.total_time, 3),
debug_info.cfg_no, debug_info.cfg.tp, debug_info.cfg.dp,
debug_info.cfg.pp, debug_info.cfg.cp, debug_info.cfg.ep,
chr(12288)))
self.logger.debug(f"-----------")
tplt = "{0:<9}\t{1:<9}"
self.logger.debug(tplt.format('w', 'b', chr(12288)))
attn_ag_model = CommPerfLinearModelFactory._instance_table["ep"]["roce"]
attn_rs_model = CommPerfLinearModelFactory._instance_table["ep"]["roce"]
self.logger.debug(tplt.format(round(attn_ag_model.w, 3), round(attn_ag_model.b, 3),
round(attn_rs_model.w, 3), round(attn_rs_model.b, 3),
chr(12288)))
self.logger.debug(f"----------------------")
self.logger.debug(f"Cross")
tplt = "{0:<8}\t{1:<8}\t{2:<8}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}\t{7:<1}\t{8:<1}"
self.logger.debug(tplt.format('hccs_x', 'roce_x', 'total_time',
'No', 'tp', 'dp', 'pp', 'cp', 'ep', chr(12288)))
for debug_info in self.debug_info_list:
if "Cross" in debug_info.model_type:
self.logger.debug(tplt.format(
round(debug_info.hccs_x, 2),
round(debug_info.roce_x, 2),
round(debug_info.total_time, 3),
debug_info.cfg_no, debug_info.cfg.tp, debug_info.cfg.dp,
debug_info.cfg.pp, debug_info.cfg.cp, debug_info.cfg.ep,
chr(12288)))
self.logger.debug(f"-----------")
self.logger.debug(f"\n\n\n")
def predict(self, search_cfg: SearchConfig):
ep = search_cfg.ep
if search_cfg.moe_tp_extend_ep:
ep = ep * search_cfg.tp
ep_time = 0.0
if not ep or ep <= 1:
return ep_time
tp = search_cfg.tp
pp = search_cfg.pp
cp = search_cfg.cp
s = search_cfg.seq_length / 1000
traffic = s * (ep - 1) * pp / ep / tp / cp
bandwidth = (ep - 1)
min_domain = tp
bandwidth_910b = (ep - 1)
bandwidth = self.hard_info.calbandwidth(bandwidth_910b, min_domain)
comm_x = traffic / bandwidth
K = ep * tp / self.max_hccs_rank_num
comm_y = s * K * pp / ep / tp / cp
comm_z = s * (K - 1) / K * pp / ep / tp / cp
iv_list = [comm_x, comm_y, comm_z]
max_domain = ep * tp
if search_cfg.moe_tp_extend_ep:
max_domain = ep
min_domain = 0
ep_time_model = CommPerfLinearModelFactory.get_or_create_model(
"ep",
min_rank_num=min_domain,
max_rank_num=max_domain,
max_hccs_dev_num=self.max_hccs_rank_num,
)
ep_time = ep_time_model.predict(*iv_list)
return ep_time