from typing import List
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication.linear_models import HCCSDomainModel
from mindspeed.auto_settings.module.communication.comm_perf_predictor import CommPerfPredictor, SimpleParallelCfg
from mindspeed.auto_settings.module.communication.communication_profile import TpProfileTimeInfo
class DebugTpComm:
def __init__(self):
self.comm_time_y = 0
self.total_time = 0
self.wait_time = 0
self.overlap_time_y = 0
self.comm_x = 0
self.cfg = SearchConfig()
self.cfg_no = 0
class TpCommPerfPredictor(CommPerfPredictor):
def __init__(self, hard_info):
super(TpCommPerfPredictor, self).__init__(hard_info)
self.is_tp_modeling = False
self.tp_total_model = HCCSDomainModel()
self.tp_overlap_model = HCCSDomainModel()
def get_communication_info_from_profile(self, tp_profile_time_info, hcom_info_tage_id):
tp_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
tp_profile_time_info.wait_comm_time += hcom_info_tage_id.wait_time_ms
tp_profile_time_info.overlap_comm_time += hcom_info_tage_id.overlap_time_ms
def receive_samples_from_profiling(
self, config_no, model_config: SearchConfig, tp_profile_time_info: TpProfileTimeInfo
):
if not self.is_tp_modeling:
return
config = model_config
tp = config.tp
cp = config.cp
pp = config.pp
if tp == 1:
return
s = config.seq_length / 1000
total_time = tp_profile_time_info.total_comm_time
wait_time = tp_profile_time_info.wait_comm_time
overlap_time = tp_profile_time_info.overlap_comm_time
traffic = s * (tp - 1) / (tp * cp) * pp
bandwidth_910b = (tp - 1)
min_domain = tp
bandwidth = self.hard_info.calbandwidth(bandwidth_910b, min_domain)
comm_x = traffic / bandwidth
cfg = SimpleParallelCfg(config_no, tp, cp, '', '', pp, '')
comm_time_y = total_time - wait_time
overlap_time_y = overlap_time
self.tp_total_model.add_sample(*(comm_x, comm_time_y, cfg))
self.tp_overlap_model.add_sample(*(comm_x, overlap_time_y, cfg))
debug_info = DebugTpComm()
debug_info.comm_x = comm_x
debug_info.comm_time_y = comm_time_y
debug_info.total_time = total_time
debug_info.wait_time = wait_time
debug_info.overlap_time_y = overlap_time_y
debug_info.cfg = config
debug_info.cfg_no = config_no
self.debug_info_list.append(debug_info)
def fit(self):
if not self.is_tp_modeling:
return
self.tp_total_model.fit()
self.tp_overlap_model.fit()
def debug(self):
if not self.is_tp_modeling:
return
self.logger.debug(f"******************profile info list***********************")
tplt = "{0:<8}\t{1:<8}\t{2:<8}\t{3:<8}\t{4:<8}\t{5:<1}\t{6:<1}\t{7:<1}\t{8:<1}\t{9:<1}\t{10:<1}\t{11:<1}"
self.logger.debug(f"****************** tp(ms) ***********************")
self.logger.debug(tplt.format('x', 'tp_time', 'overlap', 'total_time', 'wait_time',
'No', 'tp', 'dp', 'pp', 'cp', 'ep', chr(12288)))
for debug_info in self.debug_info_list:
if debug_info.cfg.use_ascend_mc2:
continue
self.logger.debug(tplt.format(
round(debug_info.comm_x, 2),
round(debug_info.comm_time_y, 3),
round(debug_info.overlap_time_y, 2),
round(debug_info.total_time, 2),
round(debug_info.wait_time, 2),
debug_info.cfg_no, debug_info.cfg.tp, debug_info.cfg.dp,
debug_info.cfg.pp, debug_info.cfg.cp, debug_info.cfg.ep,
chr(12288)))
self.logger.debug(f"-----------")
tplt = "{0:<9}\t{1:<9}\t{2:<9}\t{3:<9}"
self.logger.debug(tplt.format('tp_w', 'tp_b', 'overlap_w', 'overlap_b', chr(12288)))
self.logger.debug(tplt.format(round(self.tp_total_model.w, 3), round(self.tp_total_model.b, 3),
round(self.tp_overlap_model.w, 3), round(self.tp_overlap_model.b, 3),
chr(12288)))
self.logger.debug(f"\n\n\n")
return
def predict(self, search_cfg: SearchConfig):
if not self.is_tp_modeling:
return 0
tp = search_cfg.tensor_model_parallel_size
cp = search_cfg.context_parallel_size
pp = search_cfg.pipeline_model_parallel_size
s = search_cfg.seq_length / 1000
tp_time = 0
if tp > 1:
traffic = s * (tp - 1) / (tp * cp) * pp
bandwidth_910b = (tp - 1)
min_domain = tp
bandwidth = self.hard_info.calbandwidth(bandwidth_910b, min_domain)
tp_x = traffic / bandwidth
tp_time = self.tp_total_model.predict(*(tp_x,))
tp_overlap_time = self.tp_overlap_model.predict(*(tp_x,))
tp_time = tp_time - tp_overlap_time
if tp_time < 0:
tp_time = 0
return tp_time