import logging
import typing
from collections import deque
from common_func.ai_stack_data_check_manager import AiStackDataCheckManager
from common_func.db_manager import DBManager
from common_func.db_name_constant import DBNameConstant
from common_func.ms_constant.number_constant import NumberConstant
from common_func.msprof_exception import ProfException
from common_func.path_manager import PathManager
from common_func.step_trace_constant import StepTraceConstant
from mscalculate.step_trace.tag_handler.tag_dispatch_handler import DispatchModelHandler
from profiling_bean.db_dto.step_trace_dto import StepTraceOriginDto
from msmodel.step_trace.ts_track_model import TsTrackModel
class CreateSubTable:
"""
create sub table
"""
data = []
sample_config = {}
db_name = None
table_name = None
@classmethod
def extract_data(cls: any, collect_data: any) -> None:
"""
extract data
:param collect_data: collect data
:return: void
"""
@classmethod
def create_table(cls: any, conn: any) -> None:
"""
get table info
:param conn: connect to database
:return: create sql, insert sql
"""
@classmethod
def connect_db(cls: any) -> None:
"""
connect and destroy database
:return: void
"""
db_path = PathManager.get_db_path(cls.sample_config.get("result_dir"), cls.db_name)
conn, cur = DBManager.create_connect_db(db_path)
cls.create_table(conn)
DBManager.destroy_db_connect(conn, cur)
@classmethod
def run(cls: any, sample_config: dict, model_handler: any) -> None:
"""
extract data and build table
:param sample_config: sample config
:param model_handler: model handler
:return: void
"""
collect_data = model_handler.get_data()
cls.sample_config = sample_config
try:
cls.extract_data(collect_data)
except ProfException as step_error:
logging.error("Fail to create step trace sub table, error code: %s", step_error)
cls.connect_db()
class CreateStepTraceData(CreateSubTable):
"""
create step trace table
step_trace_data存储每个子图迭代的开始结束时间
"""
data = []
sample_config = {}
db_name = DBNameConstant.DB_STEP_TRACE
table_name = DBNameConstant.TABLE_STEP_TRACE_DATA
@classmethod
def extract_data(cls: any, collect_data: any) -> None:
for model_id, model_data in collect_data.items():
for index_id, index_data in model_data.items():
cls.data.append(
[index_id, model_id, index_data.get(
StepTraceConstant.STEP_START),
index_data.get(StepTraceConstant.STEP_END)])
@classmethod
def create_table(cls: any, conn: any) -> None:
cls.data = sorted(cls.data, key=lambda x: x[3])
for index, datum in enumerate(cls.data):
datum.append(index + 1)
create_sql = "create table if not exists {0} (index_id int, model_id int, " \
"step_start int, step_end int, iter_id int)".format(cls.table_name)
DBManager.execute_sql(conn, create_sql)
if not cls.data:
return
insert_sql = 'insert into {0} values ({1})'.format(
DBNameConstant.TABLE_STEP_TRACE_DATA, ",".join("?" * len(cls.data[0])))
DBManager.executemany_sql(conn, insert_sql, cls.data)
class CreateStepTime(CreateSubTable):
"""
create StepTime table
StepTime存储每个Step的开始结束时间
"""
data = []
sample_config = {}
db_name = DBNameConstant.DB_STEP_TRACE
table_name = DBNameConstant.TABLE_STEP_TIME
@classmethod
def extract_data(cls: any, collect_data: any = None) -> None:
with TsTrackModel(cls.sample_config.get("result_dir"), DBNameConstant.DB_STEP_TRACE,
[DBNameConstant.TABLE_STEP_TRACE]) as ts:
step_data = ts.get_step_trace_with_tag([StepTraceConstant.STEP_START_TAG, StepTraceConstant.STEP_END_TAG])
step_time = {}
for data in step_data:
step_time.setdefault(data.index_id, []).append(data)
for index_id, data in step_time.items():
if len(data) != 2:
logging.error("The step trace data is missing in step %d.", index_id)
continue
cls.data.append(
[index_id, data[0].model_id, data[0].timestamp, data[1].timestamp, index_id]
)
@classmethod
def create_table(cls: any, conn: any) -> None:
if not cls.data:
return
cls.data = sorted(cls.data, key=lambda x: x[3])
create_sql = "create table if not exists {0} (index_id int, model_id int, " \
"step_start int, step_end int, iter_id int)".format(cls.table_name)
DBManager.execute_sql(conn, create_sql)
insert_sql = 'insert into {0} values ({1})'.format(
cls.table_name, ",".join("?" * len(cls.data[0])))
DBManager.executemany_sql(conn, insert_sql, cls.data)
class CreateAllReduce(CreateSubTable):
"""
create all reduce table
"""
data = []
sample_config = {}
db_name = DBNameConstant.DB_TRACE
table_name = DBNameConstant.TABLE_ALL_REDUCE
@classmethod
def extract_data(cls: any, collect_data: any) -> None:
for model_id, model_data in collect_data.items():
for index_id, index_data in model_data.items():
for each_reduce in index_data[StepTraceConstant.ALL_REDUCE]:
cls.data.append(
[cls.sample_config.get("devices"), model_id, index_id,
index_data[StepTraceConstant.STEP_END], each_reduce[StepTraceConstant.REDUCE_START],
each_reduce[StepTraceConstant.REDUCE_END]])
@classmethod
def create_table(cls: any, conn: any) -> None:
create_sql = "create table if not exists {}" \
"(device_id int, model_id int, index_id int," \
"iteration_end int, start int, end int, primary key(device_id," \
"iteration_end, start))".format(DBNameConstant.TABLE_ALL_REDUCE)
DBManager.execute_sql(conn, create_sql)
if not cls.data:
return
insert_sql = 'insert into {0} values ({1})'.format(
DBNameConstant.TABLE_ALL_REDUCE, ",".join("?" * len(cls.data[0])))
DBManager.executemany_sql(conn, insert_sql, cls.data)
class CreateTrainingTrace(CreateSubTable):
"""
create training trace table
"""
data = []
sample_config = {}
db_name = DBNameConstant.DB_TRACE
table_name = DBNameConstant.TABLE_TRAINING_TRACE
@staticmethod
def check_timestamp(model_id: int, index_id: int, step_start: int, step_end: int) -> None:
"""
check timestamp
:param model_id: model id
:param index_id: index id
:param step_start: starting timestamp of a iter
:param step_end: ending timestamp of a iter
:return:
"""
if step_start == NumberConstant.NULL_NUMBER or step_end == NumberConstant.NULL_NUMBER:
logging.error("step time is None, step start: %d, step end: %d, "
"model id: %d, index id: %d", step_start, step_end, model_id, index_id)
raise ProfException(ProfException.PROF_INVALID_STEP_TRACE_ERROR)
if step_start == step_end:
logging.error("start time equals to end time, step start: %d, step end: %d, "
"model id: %d, index id: %d", step_start, step_end, model_id, index_id)
raise ProfException(ProfException.PROF_INVALID_STEP_TRACE_ERROR)
@classmethod
def extract_data(cls: any, collect_data: any) -> None:
for model_id, model_data in collect_data.items():
for index_id, index_data in model_data.items():
cls.update_data(model_id, index_id, index_data)
@classmethod
def create_table(cls: any, conn: any) -> None:
cls.update_step_time()
create_sql = "create table if not exists {0} " \
"(device_id int, model_id int, iteration_id int, " \
"FP_start int, " \
"BP_end int, iteration_end int, " \
"iteration_time int, fp_bp_time int, grad_refresh_bound int, " \
"data_aug_bound int)".format(DBNameConstant.TABLE_TRAINING_TRACE)
DBManager.execute_sql(conn, create_sql)
if not cls.data:
return
cls.update_data_aug()
insert_sql = 'insert into {0} values ({1})'.format(
DBNameConstant.TABLE_TRAINING_TRACE, ",".join("?" * len(cls.data[0])))
DBManager.executemany_sql(conn, insert_sql, cls.data)
@classmethod
def update_data(cls: any, model_id: int, index_id: int, index_data: dict) -> None:
"""
update data
:param model_id: model id
:param index_id: index id
:param index_data: index data
:return:
"""
training_trace = index_data.get(StepTraceConstant.TRAINING_TRACE, {})
step_start = index_data.get(StepTraceConstant.STEP_START, NumberConstant.NULL_NUMBER)
step_end = index_data.get(StepTraceConstant.STEP_END, NumberConstant.NULL_NUMBER)
cls.check_timestamp(model_id, index_id, step_start, step_end)
forward_pro = training_trace.get(StepTraceConstant.FORWARD_PROPAGATION, NumberConstant.NULL_NUMBER)
back_pro = training_trace.get(StepTraceConstant.BACK_PROPAGATION, NumberConstant.NULL_NUMBER)
iteration_time = step_end - step_start
fp_bp_time = \
NumberConstant.NULL_NUMBER if not (forward_pro and back_pro) else back_pro - forward_pro
grad_refresh_bound = \
NumberConstant.NULL_NUMBER if not back_pro else step_end - back_pro
cls.data.append([cls.sample_config.get("devices"), model_id, index_id, forward_pro,
back_pro, step_end, iteration_time, fp_bp_time,
grad_refresh_bound, NumberConstant.NULL_NUMBER])
@classmethod
def update_step_time(cls: any):
with TsTrackModel(cls.sample_config.get("result_dir"), DBNameConstant.DB_STEP_TRACE,
[DBNameConstant.TABLE_STEP_TIME]) as ts:
step_data = ts.get_step_trace_data(DBNameConstant.TABLE_STEP_TIME)
for data in step_data:
cls.data.append([cls.sample_config.get("devices"), data.model_id, data.index_id, 0, 0,
data.step_end, data.step_end - data.step_start, 0, 0, NumberConstant.NULL_NUMBER])
@classmethod
def update_data_aug(cls: any) -> None:
"""
update data aug
:return: void
"""
cls.data.sort(key=lambda datum: datum[NumberConstant.STEP_END])
for current_iter_index, current_datum in enumerate(cls.data):
if current_datum[NumberConstant.FORWARD_PROPAGATION]:
last_iter_index = cls.__find_closest_step_end_index(
current_iter_index)
if last_iter_index >= 0:
current_datum[
NumberConstant.DATA_AUG_BOUND] = \
current_datum[
NumberConstant.FORWARD_PROPAGATION] - \
cls.data[last_iter_index][
NumberConstant.STEP_END]
@classmethod
def __find_closest_step_end_index(cls: any, current_iter_index: int) -> int:
"""
find last iter index. Last iter is the closest to
current iter, and its iter end is before fp of current iter
:param current_iter_index: current_iter_index
:return: void
"""
last_iter_index = current_iter_index - 1
while last_iter_index >= 0:
if cls.data[current_iter_index][NumberConstant.FORWARD_PROPAGATION] > \
cls.data[last_iter_index][NumberConstant.STEP_END]:
break
last_iter_index = last_iter_index - 1
if last_iter_index < current_iter_index - 1:
logging.warning("The last iter of the %s iter of "
"total iters is not %s, but is %s",
current_iter_index,
current_iter_index - 1, last_iter_index)
return last_iter_index
class GetNextCreator(CreateSubTable):
"""
create get_next table
"""
data = []
sample_config = {}
db_name = DBNameConstant.DB_TRACE
table_name = DBNameConstant.TABLE_GET_NEXT
getnext_start = dict()
getnext_end = dict()
TIMESTAMP = "timestamp"
@classmethod
def extract_data(cls: typing.Any, collect_data: typing.Any) -> None:
for model_data in collect_data.values():
for index_data in model_data.values():
cls.update_data(index_data)
for model_id, model_data in cls.getnext_start.items():
for index_id, index_data in model_data.items():
for key, start_deque in index_data.items():
end_deque = cls.getnext_end.get(model_id, {}).get(index_id, {}).get(key, deque())
cls.match_getnext(start_deque, end_deque, model_id, index_id)
@classmethod
def match_getnext(cls: any, start_deque: deque, end_deque: deque, model_id: int, index_id: int) -> None:
mismatch_count = 0
while start_deque and end_deque:
start_record = start_deque.popleft()
end_record = end_deque.popleft()
while end_deque and end_record.get(cls.TIMESTAMP, 0) < start_record.get(cls.TIMESTAMP, 0):
mismatch_count += 1
end_record = end_deque.popleft()
while start_deque and start_deque[0].get(cls.TIMESTAMP, 0) <= end_record.get(cls.TIMESTAMP, 0):
mismatch_count += 1
start_record = start_deque.popleft()
cls.data.append(
[
model_id,
index_id,
start_record.get(cls.TIMESTAMP, 0),
end_record.get(cls.TIMESTAMP, 0)
]
)
if start_deque or end_deque:
logging.error("The getnext mismatch happen with model_id: %d, index_id: %d, start len: %d, end len: %d.",
model_id, index_id, len(start_deque), len(end_deque))
if mismatch_count > 0:
logging.error("There are %d getnext mismatching.", mismatch_count)
@classmethod
def update_data(cls: typing.Any, index_data: typing.Dict) -> None:
for key, records in index_data.get(StepTraceConstant.GET_NEXT, {}).items():
for record in records:
model_id = record.get("model_id", 0)
index_id = record.get("index_id", 1)
stream_key = str(record.get("stream_id", 0)) + "_" + str(key)
if record[StepTraceConstant.TAG_ID] % 2 == 0:
cls.getnext_start.setdefault(model_id, {})\
.setdefault(index_id, {}).setdefault(stream_key, deque()).append(record)
else:
cls.getnext_end.setdefault(model_id, {})\
.setdefault(index_id, {}).setdefault(stream_key, deque()).append(record)
@classmethod
def create_table(cls: typing.Any, conn: typing.Any) -> None:
create_sql = "create table if not exists {0} " \
"(model_id INTEGER, index_id INTEGER, start_time INTEGER, end_time INTEGER)".format(cls.table_name)
DBManager.execute_sql(conn, create_sql)
if not cls.data:
return
insert_sql = 'insert into {0} values ({1})'.format(cls.table_name, ",".join("?" * len(cls.data[0])))
DBManager.executemany_sql(conn, insert_sql, cls.data)
class StepTableBuilder:
"""
create table from step trace
"""
TIMESTAMP_INDEX = 2
model_handler = DispatchModelHandler()
table_list = [CreateStepTraceData, CreateStepTime, CreateAllReduce, CreateTrainingTrace, GetNextCreator]
step_conn = None
step_curs = None
@staticmethod
def to_dict(record: list) -> dict:
"""
transform list to dict
:param record: contain "model_id", "tag_id", timestamp
:return:
"""
record_dict = {
StepTraceConstant.INDEX_ID: record.index_id, StepTraceConstant.MODEL_ID: record.model_id,
StepTraceConstant.TIME_STAMP: record.timestamp, StepTraceConstant.TAG_ID: record.tag_id,
StepTraceConstant.STREAM_ID: record.stream_id
}
return record_dict
@classmethod
def process_step_trace(cls: any, step_trace: list) -> None:
"""
process step trace
:param step_trace: contain model_id, tag_id, timestamp
:return: void
"""
step_trace.sort(key=lambda x: (x.model_id, x.timestamp))
step_trace = StepTracePreProcess().reorder_step_trace_for_pipe_stage(step_trace)
for record in step_trace:
cls.model_handler.receive_record(cls.to_dict(record))
@classmethod
def build_table(cls: any, sample_config: dict) -> None:
"""
build step trace data, all reduce, training trace table
:param sample_config: sample config
:return: void
"""
if AiStackDataCheckManager.contain_training_trace_data(sample_config.get("result_dir")):
cls.table_list.remove(CreateAllReduce)
cls.table_list.remove(CreateTrainingTrace)
cls.table_list.remove(GetNextCreator)
for create_table in cls.table_list:
create_table.run(sample_config, cls.model_handler)
@classmethod
def run(cls: any, sample_config: dict) -> None:
"""
run class
:param sample_config: sample config
:return: void
"""
cls._connect_step_db(sample_config)
step_trace_data = cls._get_step_trace_data()
if not step_trace_data:
return
cls.process_step_trace(step_trace_data)
cls.build_table(sample_config)
@classmethod
def _get_step_data(cls: any, table_name: str, is_helper=False) -> list:
if not DBManager.judge_table_exist(cls.step_curs, table_name):
return []
if is_helper:
select_sql = "select DISTINCT index_id, model_id, " \
"timestamp, tag_id, 0 as stream_id from {}".format(table_name)
else:
select_sql = "select DISTINCT index_id, model_id, " \
"timestamp, tag_id, stream_id from {}".format(table_name)
return DBManager.fetch_all_data(cls.step_curs, select_sql, dto_class=StepTraceOriginDto)
@classmethod
def _connect_step_db(cls: any, sample_config: dict) -> None:
db_path = PathManager.get_db_path(sample_config.get("result_dir"), DBNameConstant.DB_STEP_TRACE)
cls.step_conn, cls.step_curs = DBManager.check_connect_db_path(db_path)
if not cls.step_conn or not cls.step_curs:
return
@classmethod
def _get_step_trace_data(cls: any) -> list:
_step_data = cls._get_step_data(DBNameConstant.TABLE_STEP_TRACE)
_helper_data = cls._get_step_data(DBNameConstant.TABLE_MODEL_WITH_Q, is_helper=True)
step_trace_data = _step_data + _helper_data
if not step_trace_data:
return []
DBManager.destroy_db_connect(cls.step_conn, cls.step_curs)
return step_trace_data
class StepTracePreProcess:
MODEL_START_TAG = 0
MODEL_END_TAG = 1
MSTX_TAG = 11
ITERATION_END_TAG = 4
HCCL_START_TAG = 10000
def __init__(self):
self.reordered_step_trace = []
self.current_model_id = None
self.current_step_trace_queue = deque()
def reorder_step_trace_for_pipe_stage(self: any, step_trace: list) -> list:
for record in step_trace:
if record.model_id != self.current_model_id:
for data in self.current_step_trace_queue:
self.reordered_step_trace.extend(data["all_record"])
self.current_step_trace_queue = deque()
self.current_model_id = record.model_id
if record.tag_id == self.MODEL_START_TAG:
self.current_step_trace_queue.append({"tag": [record], "all_record": [record]})
elif record.tag_id == self.MODEL_END_TAG:
self.deal_model_end_tag(record)
elif StepTraceConstant.GET_NEXT_START_TAG <= record.tag_id < StepTraceConstant.STEP_START_TAG:
if self.current_step_trace_queue:
self.current_step_trace_queue[-1]["all_record"].append(record)
elif record.tag_id >= self.HCCL_START_TAG:
self.deal_hccl_tag(record)
elif record.tag_id == self.MSTX_TAG:
self.current_step_trace_queue.append({"tag": [record], "all_record": [record]})
else:
self.deal_iteration_tag(record)
if self.current_step_trace_queue:
for data in self.current_step_trace_queue:
self.reordered_step_trace.extend(data.get("all_record", []))
return self.reordered_step_trace
def deal_model_end_tag(self, record: any):
model_start_tag_num = 0
while self.current_step_trace_queue:
if self.current_step_trace_queue[0]["tag"][0].tag_id == self.MODEL_START_TAG:
model_start_tag_num += 1
if model_start_tag_num == 2:
break
self.reordered_step_trace.extend(self.current_step_trace_queue.popleft()["all_record"])
self.reordered_step_trace.append(record)
def deal_hccl_tag(self, record: any):
for data in self.current_step_trace_queue:
if data["tag"][-1].tag_id != self.ITERATION_END_TAG:
data["all_record"].append(record)
break
def deal_iteration_tag(self, record: any):
is_new_iteration = True
for data in self.current_step_trace_queue:
if data["tag"][-1].tag_id < record.tag_id:
data["tag"].append(record)
data["all_record"].append(record)
is_new_iteration = False
break
if is_new_iteration:
self.current_step_trace_queue.append({"tag": [record], "all_record": [record]})