from common_func.constant import Constant
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.empty_class import EmptyClass
from common_func.utils import Utils
from msmodel.interface.base_model import BaseModel
from profiling_bean.db_dto.step_trace_dto import StepTraceDto
class MsprofStep:
"""
mainly process iteration
"""
def __init__(self: any, result_dir: str) -> None:
self._result_dir = result_dir
self.data = []
self.model = BaseModel(self._result_dir, DBNameConstant.DB_STEP_TRACE, [DBNameConstant.TABLE_STEP_TRACE_DATA])
def __enter__(self):
if not self.model.check_table():
return self
self.get_step_data()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.model.finalize()
def get_step_data(self: any) -> None:
"""
get all data from table step trace data
:return:
"""
self.model.init()
all_data_sql = "select * from {}".format(DBNameConstant.TABLE_STEP_TRACE_DATA)
self.data = DBManager.fetch_all_data(self.model.cur, all_data_sql, dto_class=StepTraceDto)
def get_step_iteration_time(self: any, index_id: int, model_id: int) -> list:
"""
get iteration time by model id and index id
:param index_id:
:param model_id:
:return:
"""
step_end_min = None
step_end_max = None
iter_id_list = self.get_iter_id(index_id, model_id)
if not iter_id_list:
return []
for data in self.data:
if data.iter_id == iter_id_list[0]:
step_end_min = data.step_end
if data.iter_id == iter_id_list[1]:
step_end_max = data.step_end
if not step_end_min or not step_end_max:
return []
return [step_end_min, step_end_max]
def get_mix_op_iter_id(self: any, index_id: int, model_id: int) -> tuple:
"""
get op single iter id in single op and graph mix scene
:param index_id:
:param model_id:
:return: [min_iter_id - 1, max_iter_id]
"""
iter_data = None
iter_id_min = Constant.DEFAULT_COUNT
iter_id_max = max([data.iter_id for data in self.data])
for data in self.data:
if data.model_id == model_id and data.index_id == index_id:
iter_data = data
break
if iter_data is None:
return ()
for data in self.data:
if data.model_id == Constant.GE_OP_MODEL_ID and data.step_start <= iter_data.step_start <= data.step_end:
iter_id_min = data.iter_id - 1
if data.model_id == Constant.GE_OP_MODEL_ID and iter_data.step_start < data.step_start:
iter_id_max = data.iter_id - 1
break
return iter_id_min, iter_id_max
def get_graph_iter_id(self: any, index_id: int, model_id: int) -> tuple:
"""
get graph iter id in single op scene or graph scene
:param index_id:
:param model_id:
:return:
"""
iter_id = None
for data in self.data:
if data.model_id == model_id and data.index_id == index_id:
iter_id = data.iter_id
break
if not iter_id:
return ()
return iter_id - 1, iter_id
def get_iter_id(self: any, index_id: int, model_id: int) -> tuple:
"""
get iter id by model id and index id
:param index_id:
:param model_id:
:return:
"""
if Utils.need_all_model_in_one_iter(self._result_dir, model_id):
return self.get_mix_op_iter_id(index_id, model_id)
return self.get_graph_iter_id(index_id, model_id)
def get_model_and_index_id_by_iter_id(self: any, iter_id: int) -> tuple:
"""
get model id and index id by iter id
:param iter_id:
:return:
"""
for data in self.data:
if data.iter_id == iter_id:
return data.model_id, data.index_id
return EmptyClass(), EmptyClass()