import re
import pandas as pd
from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.prof_common.db_manager import DBManager
# Module-level logger from msprof_analyze's shared logging helper.
# NOTE(review): not referenced by any code visible in this chunk.
logger = get_logger()
# SQL template for exporting mstx mark events: MSTX_EVENTS rows with
# eventType = 3 (presumably the "mark" event type; confirm against the msprof
# DB schema). Filled in with str.format:
#   {with_clause}        - optional CTE prefix (e.g. FRAMEWORK_API) or "".
#   {device_start_ts}    - SQL expression for the device timestamp, or "0".
#   {framework_start_ts} - SQL expression for the framework timestamp, or "0".
#   {task_join}          - optional "LEFT JOIN TASK ..." clause, or "".
#   {framework_join}     - optional "LEFT JOIN FRAMEWORK_API ..." clause, or "".
# The two "?" parameters bound MSTX_EVENTS.startNs inclusively (start, end).
# NOTE(review): the optional joins match on connectionId; if several TASK or
# PYTORCH_API rows share a connectionId, one mark yields multiple rows. Confirm
# that this is intended.
MARK_QUERY_TEMPLATE = """
{with_clause}
SELECT
MSG_IDS.value AS "msg",
MSTX_EVENTS.startNs AS "cann_ts",
{device_start_ts} AS "device_ts",
{framework_start_ts} AS "framework_ts",
MSTX_EVENTS.globalTid AS "tid"
FROM
MSTX_EVENTS
{task_join}
{framework_join}
LEFT JOIN
STRING_IDS AS MSG_IDS
ON MSTX_EVENTS.message = MSG_IDS.id
WHERE
MSTX_EVENTS.eventType = 3 AND MSTX_EVENTS.startNs >= ? AND MSTX_EVENTS.startNs <= ?
ORDER BY
MSTX_EVENTS.startNs
"""
class MstxMarkExport(BaseStatsExport):
    """Export mstx mark events from a profiling database within a time window.

    Each result row carries the columns defined by MARK_QUERY_TEMPLATE:
    ``msg``, ``cann_ts``, ``device_ts``, ``framework_ts`` and ``tid``.

    - ``framework_ts`` comes from PYTORCH_API when that table exists, otherwise 0.
    - ``device_ts`` comes from TASK when that table exists, otherwise 0.
    """

    def __init__(self, db_path, recipe_name, step_range):
        super().__init__(db_path, recipe_name, step_range)
        self._query = self.get_query_statement()

    def get_query_statement(self):
        """Build the mark-export SQL, adapting joins to the tables present.

        Returns:
            str: MARK_QUERY_TEMPLATE with its format placeholders filled. The
            remaining two ``?`` placeholders are, presumably, bound by the base
            class in the order returned by ``get_param_order`` (start, end).
        """
        # Probe by DB path with check_tables_in_db, the same API MstxRangeExport
        # uses for TASK. judge_table_exists operates on a cursor, not a path.
        has_pytorch_api = DBManager.check_tables_in_db(self._db_path, "PYTORCH_API")
        has_task = DBManager.check_tables_in_db(self._db_path, "TASK")

        # Defaults used when PYTORCH_API is absent: no CTE, no join, a constant 0.
        with_clause = ""
        framework_join = ""
        framework_start_ts = "0"
        if has_pytorch_api:
            # Resolve PYTORCH_API.connectionId through CONNECTION_IDS so it can
            # be matched against MSTX_EVENTS.connectionId.
            with_clause = """
WITH
FRAMEWORK_API AS (
SELECT
PYTORCH_API.startNs,
CONNECTION_IDS.connectionId
FROM
PYTORCH_API
LEFT JOIN
CONNECTION_IDS
ON PYTORCH_API.connectionId = CONNECTION_IDS.id
)
"""
            framework_join = "LEFT JOIN FRAMEWORK_API ON MSTX_EVENTS.connectionId = FRAMEWORK_API.connectionId"
            framework_start_ts = "FRAMEWORK_API.startNs"

        # Defaults used when TASK is absent: no join, device timestamp fixed to 0.
        task_join = ""
        device_start_ts = "0"
        if has_task:
            task_join = "LEFT JOIN TASK ON MSTX_EVENTS.connectionId = TASK.connectionId"
            device_start_ts = "TASK.startNs"

        return MARK_QUERY_TEMPLATE.format(
            with_clause=with_clause,
            device_start_ts=device_start_ts,
            framework_start_ts=framework_start_ts,
            task_join=task_join,
            framework_join=framework_join,
        )

    def get_param_order(self):
        """Return parameter names matching the query's two ``?`` placeholders, in order."""
        return [Constant.START_NS, Constant.END_NS]
# SQL template for exporting mstx range events: MSTX_EVENTS rows with
# eventType = 2 (presumably the "range" event type; confirm against the msprof
# DB schema). Filled in with str.format:
#   {device_start_ts} / {device_end_ts} - TASK timestamp columns, or "0".
#   {task_join}                         - "LEFT JOIN TASK ..." clause, or "".
# The two "?" parameters bound MSTX_EVENTS.startNs inclusively (start, end).
# Only a range's start must fall in the window; its endNs is not filtered.
# Rows with connectionId 4294967295 (0xFFFFFFFF, the uint32 maximum) are
# excluded. NOTE(review): this is presumably a "no connection" sentinel; confirm.
RANGE_QUERY_TEMPLATE = '''
SELECT
MSG_IDS.value AS "msg",
MSTX_EVENTS.startNs AS "cann_start_ts",
MSTX_EVENTS.endNs AS "cann_end_ts",
{device_start_ts} AS "device_start_ts",
{device_end_ts} AS "device_end_ts",
MSTX_EVENTS.globalTid AS "tid"
FROM
MSTX_EVENTS
{task_join}
LEFT JOIN
STRING_IDS AS MSG_IDS
ON MSTX_EVENTS.message = MSG_IDS.id
WHERE
MSTX_EVENTS.eventType = 2 AND MSTX_EVENTS.startNs >= ? AND MSTX_EVENTS.startNs <= ?
AND
MSTX_EVENTS.connectionId != 4294967295
ORDER BY
MSTX_EVENTS.startNs
'''
class MstxRangeExport(BaseStatsExport):
    """Export mstx range events from a profiling database within a start-time window.

    Device start/end timestamps are taken from the TASK table when the database
    contains one. Otherwise both device columns are reported as 0.
    """

    def __init__(self, db_path, recipe_name, param_dict):
        super().__init__(db_path, recipe_name, param_dict)
        self.set_query()

    def set_query(self):
        """Pick the TASK-joined or TASK-less query and store it on ``self._query``."""
        task_available = DBManager.check_tables_in_db(self._db_path, "TASK")
        self._query = (
            self.get_query_statement_with_task()
            if task_available
            else self.get_query_statement_no_task()
        )

    def get_query_statement_with_task(self):
        """Return the range query joined onto TASK for device-side timestamps."""
        task_fields = {
            "task_join": "LEFT JOIN TASK ON MSTX_EVENTS.connectionId = TASK.connectionId",
            "device_start_ts": "TASK.startNs",
            "device_end_ts": "TASK.endNs",
        }
        return RANGE_QUERY_TEMPLATE.format(**task_fields)

    def get_query_statement_no_task(self):
        """Return the range query with no TASK join and device timestamps fixed to 0."""
        return RANGE_QUERY_TEMPLATE.format(task_join="", device_end_ts="0", device_start_ts="0")

    def get_param_order(self):
        """Return parameter names matching the query's two ``?`` placeholders, in order."""
        return [Constant.START_NS, Constant.END_NS]