import abc
from typing import List
from collections import namedtuple
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication.communication_profile import ProfileTimeInfo
from mindspeed.auto_settings.utils.logger import get_logger
SimpleParallelCfg = namedtuple(
"SimpleParallelCfg", field_names=["config_no", "tp", "cp", "dp", "ep", "pp", "vp"]
)
class CommPerfPredictor:
def __init__(self, hard_info):
self.logger = get_logger("CommPerfPredictor")
self.max_hccs_rank_num = hard_info.max_hccs_rank_num
self.hard_info = hard_info
self.debug_info_list = []
@abc.abstractmethod
def get_communication_info_from_profile(self, hcom_info_tage_id):
pass
@abc.abstractmethod
def receive_samples_from_profiling(
self, config_no, model_config: SearchConfig, profile_info: ProfileTimeInfo
):
"""Parse profiling info and extract the samples including 'x'(s) and 'y' and add to the
linear models.
:param model_config:
:param profile_info:
:return:
"""
pass
@abc.abstractmethod
def fit(self):
"""Trigger all the linear models to fit.
:return:
"""
pass
@abc.abstractmethod
def predict(self, search_cfg: SearchConfig):
"""Predict the communication time based on the given model configuration searched.
:param search_cfg: The configuration of the search.
:return: The predicted communication time.
"""
pass
@abc.abstractmethod
def debug(self, config_list: List[SearchConfig]):
"""Print model configurations and the linear models' samples and fitted parameters.
:return:
"""
pass