from enum import Enum
import os
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication import communication_profile
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.module.communication.cp_comm_perf_predictor import CpCommPerfPredictor
from mindspeed.auto_settings.module.communication.dp_comm_perf_predictor import DpCommPerfPredictor
from mindspeed.auto_settings.module.communication.ep_comm_perf_predictor import EpCommPerfPredictor
from mindspeed.auto_settings.module.communication.mc2_comm_perf_predictor import Mc2CommPerfPredictor
from mindspeed.auto_settings.module.communication.pp_comm_perf_predictor import PpCommPerfPredictor
from mindspeed.auto_settings.module.communication.tp_comm_perf_predictor import TpCommPerfPredictor
from mindspeed.auto_settings.module.communication.comm_hardware.comm_hardware_info import CommHardInfo, HccsDev
class CommPerfPredictorManager(object):
    """Own one communication-performance predictor per parallel dimension.

    The manager builds TP/CP/DP/PP/EP/MC2 predictors that share a single
    ``CommHardInfo``, feeds them with samples extracted from profiling
    results, and then asks each of them to ``fit()``.
    """

    # Predictor attributes, in the order they are created in ``__init__``.
    _PREDICTOR_ATTRS = (
        "tp_predictor",
        "cp_predictor",
        "dp_predictor",
        "pp_predictor",
        "ep_predictor",
        "mc2_predictor",
    )

    # Values of HCCL_INTER_HCCS_DISABLE that mean "HCCS is NOT disabled".
    _HCCS_NOT_DISABLED_VALUES = ("", "0", "FALSE")

    def __init__(self, hardware=None, model_cfg=None):
        """Create the predictors and switch on the requested modeling flags.

        Args:
            hardware: Object exposing ``device_type``. Although the default
                is ``None``, ``hardware.device_type`` is read unconditionally,
                so a real object must be supplied.
            model_cfg: Object exposing ``parallel_switch`` (a container tested
                with ``in``). Same remark as above regarding ``None``.
        """
        self.hardware = hardware
        self.model_config = model_cfg

        hard_info = CommHardInfo(self.hardware.device_type)
        self.tp_predictor = TpCommPerfPredictor(hard_info)
        self.cp_predictor = CpCommPerfPredictor(hard_info)
        self.dp_predictor = DpCommPerfPredictor(hard_info)
        self.pp_predictor = PpCommPerfPredictor(hard_info)
        self.ep_predictor = EpCommPerfPredictor(hard_info)
        self.mc2_predictor = Mc2CommPerfPredictor(hard_info)

        switches = model_cfg.parallel_switch
        if "tp" in switches:
            self.tp_predictor.is_tp_modeling = True
        # CP modeling is only enabled for sequence lengths of at least 8K.
        if (
            "cp" in switches
            and get_system_config().model_config.seq_length >= 8 * 1024
        ):
            self.cp_predictor.is_cp_modeling = True
        # EP modeling requires a truthy ``num_experts`` in the system config.
        if "ep" in switches and get_system_config().model_config.num_experts:
            self.ep_predictor.is_ep_modeling = True
        if "mc2" in switches:
            # NOTE(review): the flag is set on ep_predictor, not mc2_predictor.
            # Kept as-is; confirm against the predictor classes whether this
            # is intentional.
            self.ep_predictor.is_mc2_modeling = True

        # Every config seen by adapt_to_profile_info, in arrival order.
        self.config_list = []

    def communication_modeling(self, profiling_results):
        """Collect profiling samples, then fit every predictor.

        Args:
            profiling_results: Iterable of ``(config, model)`` pairs.
        """
        self.adapt_to_profile_info(profiling_results)
        self.info_to_modeling()

    def communication_roce_predict_fix(self):
        """Override ``max_hccs_rank_num`` on all predictors when HCCS is disabled.

        Only applies when ``hardware.device_type`` contains ``"910_93"`` and
        the ``HCCL_INTER_HCCS_DISABLE`` environment variable holds an
        affirmative value. An unset variable, or an explicit negative such as
        ``FALSE`` or ``0`` (case-insensitive), leaves the predictors untouched.
        """
        if "910_93" not in self.hardware.device_type:
            return

        raw_value = os.getenv("HCCL_INTER_HCCS_DISABLE", "")
        # A plain truthiness test would treat "FALSE" as "disabled".
        if raw_value.strip().upper() in self._HCCS_NOT_DISABLED_VALUES:
            return

        hccs_rank_num = HccsDev.hccs_dev_num_910_93.value
        for attr_name in self._PREDICTOR_ATTRS:
            getattr(self, attr_name).max_hccs_rank_num = hccs_rank_num

    def adapt_to_profile_info(self, profiling_results):
        """Extract profile timings per config and hand them to the predictors.

        Each config is appended to ``self.config_list``. As a side effect,
        ``config.use_ascend_mc2`` is overwritten with ``0`` on every config.

        Args:
            profiling_results: Iterable of ``(config, model)`` pairs.
        """
        for index, (config, model) in enumerate(profiling_results):
            total_profile_time_info = communication_profile.TotalProfileTimeInfo()
            self.config_list.append(config)
            self.get_profile_info(
                model, total_profile_time_info, config, profiling_results, index
            )

            # MC2 is forced off here, so the mc2 branch below is currently
            # unreachable and the TP predictor always receives the sample.
            config.use_ascend_mc2 = 0
            if config.use_ascend_mc2:
                self.mc2_predictor.receive_samples_from_profiling(
                    index, config, total_profile_time_info.mc2_profile_time_info
                )
            else:
                self.tp_predictor.receive_samples_from_profiling(
                    index, config, total_profile_time_info.tp_profile_time_info
                )

            self.dp_predictor.receive_samples_from_profiling(
                index, config, total_profile_time_info.dp_profile_time_info
            )
            self.cp_predictor.receive_samples_from_profiling(
                index, config, total_profile_time_info.cp_profile_time_info
            )
            self.ep_predictor.receive_samples_from_profiling(
                index, config, total_profile_time_info.ep_profile_time_info
            )
            self.pp_predictor.receive_samples_from_profiling(
                index, config, total_profile_time_info.pp_profile_time_info
            )

    def info_to_modeling(self):
        """Call ``fit()`` and then ``debug()`` on every predictor."""
        self.tp_predictor.fit()
        # NOTE(review): unlike the other predictors, TP's debug() is called
        # without config_list; confirm this matches its signature.
        self.tp_predictor.debug()
        self.mc2_predictor.fit()
        self.mc2_predictor.debug(self.config_list)
        self.dp_predictor.fit()
        self.dp_predictor.debug(self.config_list)
        self.cp_predictor.fit()
        self.cp_predictor.debug(self.config_list)
        self.ep_predictor.fit()
        self.ep_predictor.debug(self.config_list)
        self.pp_predictor.fit()
        self.pp_predictor.debug(self.config_list)

    def get_profile_info(
        self, model, total_profile_time_info, config: "SearchConfig", profiling_results, index
    ):
        """Fill ``total_profile_time_info`` from stage 0 of each comm list.

        Nothing is extracted when ``model.tensor_parallel_comm`` is empty.
        Otherwise the first entry of each non-empty communication list is
        passed to the matching predictor. TP and PP are additionally skipped
        unless ``config.tp > 1`` / ``config.pp > 1`` respectively.

        As a side effect ``config.use_ascend_mc2`` is overwritten with ``0``.

        Args:
            model: Object exposing ``tensor_parallel_comm``,
                ``data_parallel_comm``, ``pipeline_parallel_comm``,
                ``context_parallel_comm`` and ``expert_parallel_comm``.
            total_profile_time_info: Container with one ``*_profile_time_info``
                attribute per predictor.
            config: Search configuration for this profiling run.
            profiling_results: Full profiling result list (MC2 path only).
            index: Position of this run in ``profiling_results`` (MC2 path only).
        """
        tensor_hcom_info = model.tensor_parallel_comm
        data_hcom_info = model.data_parallel_comm
        pipeline_hcom_info = model.pipeline_parallel_comm
        context_hcom_info = model.context_parallel_comm
        expert_hcom_info = model.expert_parallel_comm

        # MC2 is forced off, so the call below is currently unreachable.
        config.use_ascend_mc2 = 0
        if config.use_ascend_mc2:
            self.mc2_predictor.get_communication_info_from_profile(
                total_profile_time_info.mc2_profile_time_info, profiling_results, index
            )

        # Only stage 0 is ever consumed, and only when tensor-parallel
        # communication info exists at all, so no per-stage loop is needed.
        if len(tensor_hcom_info) == 0:
            return

        if config.tp > 1:
            self.tp_predictor.get_communication_info_from_profile(
                total_profile_time_info.tp_profile_time_info, tensor_hcom_info[0]
            )
        if len(context_hcom_info) > 0:
            self.cp_predictor.get_communication_info_from_profile(
                total_profile_time_info.cp_profile_time_info,
                context_hcom_info[0],
                model,
                config.cp,
            )
        if config.pp > 1 and len(pipeline_hcom_info) > 0:
            self.pp_predictor.get_communication_info_from_profile(
                total_profile_time_info.pp_profile_time_info,
                pipeline_hcom_info[0],
                config.pp,
            )
        if len(data_hcom_info) > 0:
            self.dp_predictor.get_communication_info_from_profile(
                total_profile_time_info.dp_profile_time_info, data_hcom_info[0]
            )
        if len(expert_hcom_info) > 0:
            self.ep_predictor.get_communication_info_from_profile(
                total_profile_time_info.ep_profile_time_info, expert_hcom_info[0]
            )