from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
from msprof_analyze.prof_common.constant import Constant
GROUPED_MATMUL_QUERY = """
SELECT
InputShapes_IDS.value AS "InputShapes"
FROM COMPUTE_TASK_INFO
JOIN TASK
ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId
LEFT JOIN STRING_IDS AS InputShapes_IDS
ON InputShapes_IDS.id = COMPUTE_TASK_INFO.inputShapes
WHERE COMPUTE_TASK_INFO.opType = (
SELECT id
FROM STRING_IDS
WHERE value = 'GroupedMatmul'
)
{}
"""
class InputShapeExport(BaseStatsExport):
def __init__(self, db_path, recipe_name, step_range):
super().__init__(db_path, recipe_name, step_range)
filter_statement = "And TASK.startNs >= ? And TASK.endNs <= ?" if step_range else ""
self._query = GROUPED_MATMUL_QUERY.format(filter_statement)