"""
算子预估
"""
import math
from mindspeed.auto_settings.config.model_config import get_model_config
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.module.communication.comm_perf_predictor_manager import CommPerfPredictorManager
from mindspeed.auto_settings.utils.logger import get_logger
class Communication(object):
    """Estimate communication time for a candidate parallel configuration.

    Wraps a ``CommPerfPredictorManager`` built from the system and model
    configs and combines its per-domain predictors (tp / mc2 / cp / ep /
    dp / pp) into the figures returned by ``get_communication_time``.
    """

    def __init__(self):
        self.logger = get_logger("Communication")
        self.predictor_mgr = CommPerfPredictorManager(get_system_config(), get_model_config())

    def train_models(self, profiling_results):
        """Pass ``profiling_results`` to the manager's ``communication_modeling``."""
        self.predictor_mgr.communication_modeling(profiling_results)

    def get_communication_time(self, config):
        """Return the predicted communication times for ``config``.

        Args:
            config: Search configuration. This method reads
                ``data_parallel_size``, ``pipeline_model_parallel_size`` and
                ``micro_batch_size`` from it, and hands the whole object to
                each predictor's ``predict``.

        Returns:
            dict with keys ``use_mc2``, ``fw_communication_time``,
            ``communication_time``, ``pp_time``, ``dp_time``, ``tp_time``,
            ``cp_time`` and ``ep_time``. ``tp_time``, ``cp_time`` and
            ``ep_time`` are the predictor outputs divided by the pipeline
            parallel size; ``pp_time`` and ``dp_time`` are returned as
            predicted.
        """
        mgr = self.predictor_mgr
        dp = config.data_parallel_size
        pp = config.pipeline_model_parallel_size
        search_micro_batch_size = config.micro_batch_size
        num_layers = mgr.model_config.num_layers
        global_batch_size = mgr.model_config.global_batch_size

        # These four predictions are taken before the RoCE fix is applied;
        # mc2/tp are predicted after it. The order is kept as-is because the
        # fix may alter predictor state.
        cp_time = mgr.cp_predictor.predict(config) / pp
        dp_time = mgr.dp_predictor.predict(config)
        pp_time = mgr.pp_predictor.predict(config)
        ep_time = mgr.ep_predictor.predict(config) / pp

        # NOTE(review): presumably the micro batch size the predictors were
        # modelled at; it is hard-coded here - confirm.
        model_micro_batch_size = 1

        mgr.communication_roce_predict_fix()
        mc2_time = mgr.mc2_predictor.predict(config)
        tp_time = mgr.tp_predictor.predict(config)
        self.logger.debug(f"mc2_time:{mc2_time} tp_time:{tp_time}")

        # Only compare against MC2 when both predictors have been modelled;
        # otherwise keep the plain TP prediction.
        use_mc2 = False
        if mgr.tp_predictor.is_tp_modeling and mgr.mc2_predictor.is_mc2_modeling:
            use_mc2 = mc2_time < tp_time
            tp_time = min(mc2_time, tp_time)
        tp_time = tp_time / pp

        micro_batch_num = global_batch_size / (dp * search_micro_batch_size)
        # Layers per pipeline stage times micro batches, rounded up.
        layer_num = math.ceil(micro_batch_num * (num_layers / pp))
        search_model_mbs_ratio = search_micro_batch_size / model_micro_batch_size

        per_layer_time = tp_time + cp_time + ep_time
        communication_time = per_layer_time * search_model_mbs_ratio * layer_num
        # NOTE(review): the 1/3 factor presumably is the forward share of a
        # forward+backward pass - confirm.
        fw_communication_time = per_layer_time / 3

        # Diagnostics go through the logger instead of stdout.
        self.logger.debug(
            f"communication: {communication_time} pp: {pp_time} "
            f"mbs: {search_model_mbs_ratio} dp: {dp_time}"
        )
        return {
            "use_mc2": use_mc2,
            "fw_communication_time": fw_communication_time,
            "communication_time": communication_time,
            "pp_time": pp_time,
            "dp_time": dp_time,
            "tp_time": tp_time,
            "cp_time": cp_time,
            "ep_time": ep_time
        }