import os
from unittest import TestCase
import pandas as pd
from msprof_analyze.cluster_analyse.recipes.hccl_sum.hccl_sum import HcclSum
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.db_manager import DBManager
from msprof_analyze.prof_common.path_manager import PathManager
from test.st.utils import execute_cmd
from test.st.utils import ST_DATA_PATH
class TestHcclSum(TestCase):
    """
    Test recipe: hccl_sum

    Runs ``msprof-analyze cluster -m hccl_sum`` once on the ST cluster data, then
    checks the tables it writes against the baseline database stored under
    ``CLUSTER_PATH/cluster_analysis_output_base``.
    """
    CLUSTER_PATH = os.path.join(ST_DATA_PATH, "cluster_data_2_db")
    OUTPUT_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "TestHcclSum")
    COMMAND_SUCCESS = 0

    @classmethod
    def setup_class(cls):
        """Run the hccl_sum recipe once and open the output and baseline DBs.

        pytest invokes this with the class object, so all state is stored as
        class attributes; the test methods read it back through ``self``.

        Raises:
            AssertionError: if the command returns a non-success code or the
                communication-analyzer DB was not produced.
        """
        PathManager.make_dir_safety(cls.OUTPUT_PATH)
        cmd = ["msprof-analyze", "cluster", "-d", cls.CLUSTER_PATH, "-m", "hccl_sum",
               "--output_path", cls.OUTPUT_PATH, "--force"]
        cls.db_path = os.path.join(cls.OUTPUT_PATH, Constant.CLUSTER_ANALYSIS_OUTPUT,
                                   Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        # OUTPUT_PATH was created just above, so checking it would always pass;
        # check for the DB file the recipe is expected to write instead.
        if execute_cmd(cmd) != cls.COMMAND_SUCCESS or not os.path.exists(cls.db_path):
            # TestCase.fail is an instance method and cannot be called on the
            # class; raise the same exception type it would raise.
            raise cls.failureException("HcclSum task failed.")
        cls.conn, cls.cursor = DBManager.create_connect_db(cls.db_path)
        cls.db_path_base = os.path.join(cls.CLUSTER_PATH, "cluster_analysis_output_base",
                                        Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        cls.conn_base, cls.cursor_base = DBManager.create_connect_db(cls.db_path_base)

    @classmethod
    def teardown_class(cls):
        """Close both DB connections and remove the recipe output directory."""
        DBManager.destroy_db_connect(cls.conn, cls.cursor)
        DBManager.destroy_db_connect(cls.conn_base, cls.cursor_base)
        PathManager.remove_path_safety(cls.OUTPUT_PATH)

    def _columns_match(self, table_name, expected_columns):
        """Return whether the column names reported for ``table_name`` equal ``expected_columns``."""
        return DBManager.get_table_columns_name(self.cursor, table_name) == expected_columns

    def _assert_table_matches_base(self, table_name):
        """Assert that ``table_name`` holds identical data in the output DB and the baseline DB.

        Uses ``DataFrame.equals`` (same shape, values and column dtypes; NaNs in
        the same position count as equal). On mismatch, both frames are included
        in the failure message.
        """
        query = f"select * from {table_name}"
        df = pd.read_sql(query, self.conn)
        df_base = pd.read_sql(query, self.conn_base)
        if not df.equals(df_base):
            self.fail(f"Data in {table_name} does not match the baseline.\n"
                      f"actual:\n{df}\nexpected:\n{df_base}")

    def check_tables_in_db(self):
        """Return whether the output DB contains all four hccl_sum tables."""
        expected_tables = [
            HcclSum.TABLE_ALL_RANK_STATS,
            HcclSum.TABLE_PER_RANK_STATS,
            HcclSum.TABLE_TOP_OP_STATS,
            HcclSum.TABLE_GROUP_NAME_MAP,
        ]
        return DBManager.check_tables_in_db(self.db_path, *expected_tables)

    def check_hccl_all_rank_stats_columns(self):
        """Return whether the all-rank stats table has the expected header."""
        expected_columns = ["OpType", "Count", "MeanNs", "StdNs", "MinNs", "Q1Ns", "MedianNs", "Q3Ns",
                            "MaxNs", "SumNs"]
        return self._columns_match(HcclSum.TABLE_ALL_RANK_STATS, expected_columns)

    def check_hccl_per_rank_stats_columns(self):
        """Return whether the per-rank stats table has the expected header."""
        expected_columns = ["OpType", "Count", "MeanNs", "StdNs", "MinNs", "Q1Ns", "MedianNs", "Q3Ns",
                            "MaxNs", "SumNs", "Rank"]
        return self._columns_match(HcclSum.TABLE_PER_RANK_STATS, expected_columns)

    def check_hccl_top_op_stats_columns(self):
        """Return whether the top-op stats table has the expected header."""
        expected_columns = ["OpName", "Count", "MeanNs", "StdNs", "MinNs", "Q1Ns", "MedianNs", "Q3Ns",
                            "MaxNs", "SumNs", "MinRank", "MaxRank"]
        return self._columns_match(HcclSum.TABLE_TOP_OP_STATS, expected_columns)

    def check_hccl_group_name_map_columns(self):
        """Return whether the group-name map table has the expected header."""
        expected_columns = ["GroupName", "GroupId", "Ranks"]
        return self._columns_match(HcclSum.TABLE_GROUP_NAME_MAP, expected_columns)

    def test_hccl_sum_should_run_success_when_given_cluster_data(self):
        self.assertTrue(self.check_tables_in_db(), msg="DB does not exist or is missing tables.")
        self.assertTrue(self.check_hccl_all_rank_stats_columns(),
                        msg=f"The header of {HcclSum.TABLE_ALL_RANK_STATS} does not meet expectations.")
        self.assertTrue(self.check_hccl_per_rank_stats_columns(),
                        msg=f"The header of {HcclSum.TABLE_PER_RANK_STATS} does not meet expectations.")
        self.assertTrue(self.check_hccl_top_op_stats_columns(),
                        msg=f"The header of {HcclSum.TABLE_TOP_OP_STATS} does not meet expectations.")
        self.assertTrue(self.check_hccl_group_name_map_columns(),
                        msg=f"The header of {HcclSum.TABLE_GROUP_NAME_MAP} does not meet expectations.")

    def test_hccl_all_rank_stats_data_when_given_cluster_data(self):
        self._assert_table_matches_base(HcclSum.TABLE_ALL_RANK_STATS)

    def test_hccl_per_rank_stats_data_when_given_cluster_data(self):
        self._assert_table_matches_base(HcclSum.TABLE_PER_RANK_STATS)

    def test_hccl_top_op_stats_data_when_given_cluster_data(self):
        self._assert_table_matches_base(HcclSum.TABLE_TOP_OP_STATS)

    def test_hccl_group_name_map_data_when_given_cluster_data(self):
        self._assert_table_matches_base(HcclSum.TABLE_GROUP_NAME_MAP)