from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
# Module-level logger; in the code visible here it is only used to report an
# unsupported table name passed to FrameworkOpToKernelExport.
logger = get_logger()
# CTE prefix (not a complete statement on its own) defining `task_connections`
# for a compute-style task table. `{compute_table}` is a str.format placeholder
# that FrameworkOpToKernelExport fills with the table name.
# Columns produced: name (STRING_IDS value of compute.name), startNs/endNs
# (from TASK, matched on globalTaskId) and api_conn_id (CONNECTION_IDS.id,
# matched on TASK.connectionId).
# All joins are LEFT JOINs, so rows without a match survive the CTE with NULLs;
# the inner JOIN on api_conn_id in QUERY_TASK_LINK_PYTORCH_API drops those that
# have no connection id.
QUERY_COMPUTE_TASK = """
WITH task_connections AS (
SELECT
str.value AS name,
task.startNs,
task.endNs,
conn.id AS api_conn_id
FROM
{compute_table} AS compute
LEFT JOIN
TASK task ON compute.globalTaskId = task.globalTaskId
LEFT JOIN
STRING_IDS str ON str.id = compute.name
LEFT JOIN
CONNECTION_IDS conn ON conn.connectionId = task.connectionId
)"""
# CTE prefix (not a complete statement on its own) defining `task_connections`
# from COMMUNICATION_OP, with the same column names as QUERY_COMPUTE_TASK so the
# same trailing SELECT can consume either one.
# Here `name` is the STRING_IDS value of comm.opType and the timestamps come
# straight from COMMUNICATION_OP. Both joins are inner joins, so rows lacking a
# string or connection-id match are dropped inside the CTE.
# This template has no placeholders and is used without str.format.
QUERY_COMMUNICATION_TASK = """
WITH task_connections AS (
SELECT
str.value AS name,
comm.startNs,
comm.endNs,
conn.id AS api_conn_id
FROM
COMMUNICATION_OP AS comm
JOIN
STRING_IDS str ON str.id = comm.opType
JOIN
CONNECTION_IDS conn ON conn.connectionId = comm.connectionId
)"""
# Main SELECT appended after one of the `task_connections` CTE prefixes.
# It matches each task row to the PYTORCH_API row whose connectionId equals the
# task's api_conn_id and resolves the API name through STRING_IDS.
# Result columns: kernel_name, kernel_ts, kernel_end, op_name, op_ts, op_end,
# ordered by op_ts then kernel_ts.
# NOTE: when combined with QUERY_COMPUTE_TASK this text is passed through
# str.format, so it must not contain literal `{` or `}` characters.
QUERY_TASK_LINK_PYTORCH_API = """
SELECT
tc.name as kernel_name,
tc.startNs as kernel_ts,
tc.endNs as kernel_end,
api_str.value AS op_name,
api.startNs as op_ts,
api.endNs as op_end
FROM
task_connections tc
JOIN
PYTORCH_API api ON tc.api_conn_id = api.connectionId
JOIN
STRING_IDS api_str ON api.name = api_str.id
ORDER BY op_ts, kernel_ts
"""
# Selects startNs, endNs and the message text (as `name`) of MSTX_EVENTS rows
# whose eventType is 2 and whose domain string is exactly 'Module', ordered by
# startNs.
# NOTE(review): eventType = 2 presumably denotes range events (per the constant
# name) - confirm against the MSTX_EVENTS schema.
# The filter on str_domain.value discards rows where the domain join found no
# match, so that LEFT JOIN effectively behaves like an inner join.
QUERY_MSTX_RANGE_WITH_DOMAIN = """
SELECT
mstx.startNs,
mstx.endNs,
str_name.value AS name
FROM
MSTX_EVENTS mstx
LEFT JOIN
STRING_IDS str_name ON mstx.message = str_name.id
LEFT JOIN
STRING_IDS str_domain ON mstx.domainId = str_domain.id
WHERE
mstx.eventType = 2 AND str_domain.value = 'Module'
ORDER BY mstx.startNs
"""
# Pairs two PYTORCH_API rows per shared connection.
# The subquery keeps every CONNECTION_IDS.connectionId that occurs more than
# once and records the smallest and largest `id` in that group. The PYTORCH_API
# row whose connectionId equals the smallest id is reported as fwd_*, the one
# matching the largest id as bwd_*.
# Rows whose forward name starts with 'Enqueue' or 'Dequeue' are excluded; rows
# with no forward name are excluded too, because NOT LIKE on NULL is not true.
# Results are ordered by connectionId.
QUERY_FWD_BWD_FLOW = """
SELECT
c.connectionId as connectionId,
fwd_ids.value as fwd_name,
fwd_pa.startNs as fwd_ts,
fwd_pa.endNs as fwd_end,
bwd_ids.value as bwd_name,
bwd_pa.startNs as bwd_ts,
bwd_pa.endNs as bwd_end
FROM (
SELECT
connectionId,
MIN(id) as min_id,
MAX(id) as max_id
FROM CONNECTION_IDS
GROUP BY connectionId
HAVING COUNT(*) > 1
) c
LEFT JOIN PYTORCH_API fwd_pa ON fwd_pa.connectionId = c.min_id
LEFT JOIN STRING_IDS fwd_ids ON fwd_ids.id = fwd_pa.name
LEFT JOIN PYTORCH_API bwd_pa ON bwd_pa.connectionId = c.max_id
LEFT JOIN STRING_IDS bwd_ids ON bwd_ids.id = bwd_pa.name
WHERE fwd_ids.value NOT LIKE 'Enqueue%' AND fwd_ids.value NOT LIKE 'Dequeue%'
ORDER BY c.connectionId
"""
# Backward-compatible alias: the constant was originally published under this
# misspelled name, so it is kept for existing references and importers.
QUEYR_FWD_BWD_FLOW = QUERY_FWD_BWD_FLOW
class FrameworkOpToKernelExport(BaseStatsExport):
    """Export that links task rows to the PYTORCH_API rows sharing a connection id.

    The SQL is built from a table-specific ``task_connections`` CTE followed by
    the shared ``QUERY_TASK_LINK_PYTORCH_API`` select.
    """

    def __init__(self, db_path, recipe_name, table_name):
        """Select the query that matches ``table_name``.

        Args:
            db_path: Forwarded unchanged to ``BaseStatsExport.__init__``.
            recipe_name: Forwarded unchanged to ``BaseStatsExport.__init__``.
            table_name: Source table. The compute-task and
                communication-schedule-task tables use the templated CTE; the
                communication-op table uses its dedicated CTE. Any other value
                is only logged as an error and ``_query`` is not assigned here.
        """
        super().__init__(db_path, recipe_name, param_dict=None)

        templated_tables = (
            Constant.TABLE_COMPUTE_TASK_INFO,
            Constant.TABLE_COMMUNICATION_SCHEDULE_TASK_INFO,
        )
        if table_name in templated_tables:
            # Concatenate first, then substitute the table name, so the whole
            # statement goes through a single str.format call.
            template = QUERY_COMPUTE_TASK + QUERY_TASK_LINK_PYTORCH_API
            self._query = template.format(compute_table=table_name)
            return

        if table_name == Constant.TABLE_COMMUNICATION_OP:
            self._query = QUERY_COMMUNICATION_TASK + QUERY_TASK_LINK_PYTORCH_API
            return

        logger.error(f"FrameworkOpToKernelExport not support {table_name}")

    def get_param_order(self):
        """Return an empty list.

        NOTE(review): presumably this means the query has no parameters to
        bind - confirm against BaseStatsExport.
        """
        return list()
class ModuleMstxRangeExport(BaseStatsExport):
    """Export backed by ``QUERY_MSTX_RANGE_WITH_DOMAIN``."""

    def __init__(self, db_path, recipe_name):
        """Initialise the base export and attach the fixed MSTX range query.

        Args:
            db_path: Forwarded unchanged to ``BaseStatsExport.__init__``.
            recipe_name: Forwarded unchanged to ``BaseStatsExport.__init__``.
        """
        super().__init__(db_path, recipe_name, param_dict=None)
        self._query = QUERY_MSTX_RANGE_WITH_DOMAIN

    def get_param_order(self):
        """Return an empty list.

        NOTE(review): presumably this means the query has no parameters to
        bind - confirm against BaseStatsExport.
        """
        return list()
class FwdBwdFlowExport(BaseStatsExport):
    """Export backed by the forward/backward flow query ``QUEYR_FWD_BWD_FLOW``."""

    def __init__(self, db_path, recipe_name):
        """Initialise the base export and attach the fixed flow query.

        Args:
            db_path: Forwarded unchanged to ``BaseStatsExport.__init__``.
            recipe_name: Forwarded unchanged to ``BaseStatsExport.__init__``.
        """
        super().__init__(db_path, recipe_name, param_dict=None)
        self._query = QUEYR_FWD_BWD_FLOW

    def get_param_order(self):
        """Return an empty list.

        NOTE(review): presumably this means the query has no parameters to
        bind - confirm against BaseStatsExport.
        """
        return list()