import logging
import os
from common_func.constant import Constant
from common_func.db_name_constant import DBNameConstant
from common_func.info_conf_reader import InfoConfReader
from common_func.ms_constant.str_constant import StrConstant
from common_func.path_manager import PathManager
from msmodel.cluster_info.cluster_info_model import ClusterInfoViewModel
from msmodel.parallel.cluster_parallel_model import ClusterParallelModel
from msmodel.parallel.parallel_model import ParallelViewModel
from msparser.interface.iparser import IParser
class ClusterParallelCollector(IParser):
THREAD_NUM = 10
def __init__(self: any, collect_path: str) -> None:
self.collect_path = collect_path
self._cluster_info = []
self._cluster_parallel_data = []
self._cluster_parallel_strategy_data = []
self._parallel_table_name = Constant.NA
def ms_run(self) -> None:
self.parse()
self.save()
def parse(self: any) -> None:
with ClusterInfoViewModel(self.collect_path) as _model:
if _model.check_table():
self._cluster_info = _model.get_all_cluster_rank_info()
if not self._cluster_info:
return
_project_path = os.path.join(self.collect_path, self._cluster_info[0].dir_name)
if not os.path.exists(PathManager.get_db_path(_project_path, DBNameConstant.DB_PARALLEL)):
return
with ParallelViewModel(_project_path) as _model:
self._parallel_table_name = _model.get_parallel_table_name()
if self._parallel_table_name == Constant.NA:
return
logging.info("Start to parse cluster parallel data!")
self.get_device_parallel_data()
def save(self: any) -> None:
if not self._cluster_parallel_data or not self._cluster_parallel_strategy_data:
logging.warning("Invalid cluster parallel data!")
return
with ClusterParallelModel(self.collect_path) as _model:
_model.create_table(self._parallel_table_name)
_model.flush(self._parallel_table_name, self._cluster_parallel_data)
_model.create_table(DBNameConstant.TABLE_CLUSTER_PARALLEL_STRATEGY)
_model.flush(DBNameConstant.TABLE_CLUSTER_PARALLEL_STRATEGY, self._cluster_parallel_strategy_data)
def get_device_parallel_data(self: any):
for cluster_info in self._cluster_info:
_project_path = os.path.join(self.collect_path, cluster_info.dir_name)
hwts_freq = InfoConfReader().get_freq(StrConstant.HWTS)
with ParallelViewModel(_project_path) as _model:
parallel_index_data = _model.get_parallel_index_data(self._parallel_table_name, cluster_info.rank_id,
cluster_info.device_id, hwts_freq)
if parallel_index_data:
self._cluster_parallel_data.extend(parallel_index_data)
parallel_strategy_data = _model.get_parallel_strategy_data()
if parallel_strategy_data:
self._cluster_parallel_strategy_data.extend(parallel_strategy_data)