import json
import os
import unittest
from unittest.mock import patch, MagicMock, call
import pandas as pd
from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.db_manager import DBManager
from msprof_analyze.prof_common.file_manager import FileManager
class TestBaseRecipeAnalysis(unittest.TestCase):
    """Unit tests for ``BaseRecipeAnalysis``, driven through a minimal concrete subclass.

    ``setUp`` builds one analysis object from a fixed parameter dict. Individual
    tests replace file-system and database collaborators with mocks and assert
    on return values and on the calls made to those mocks.
    """

    def setUp(self):
        """Create the shared parameter dict and the ``self.analysis`` fixture."""
        self.params = {
            Constant.COLLECTION_PATH: '/tmp/to/collection',
            Constant.DATA_MAP: {0: '/tmp/to/data/0', 1: '/tmp/to/data/1'},
            Constant.RECIPE_NAME: 'test_recipe',
            Constant.PARALLEL_MODE: 'parallel',
            Constant.EXPORT_TYPE: 'csv',
            Constant.PROFILING_TYPE: Constant.PYTORCH,
            Constant.IS_MSPROF: False,
            Constant.IS_MINDSPORE: False,
            Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: '/tmp/to/output',
            Constant.RANK_LIST: '0,1',
            Constant.STEP_ID: 1,
            Constant.EXTRA_ARGS: []
        }

        class ConcreteRecipeAnalysis(BaseRecipeAnalysis):
            """Subclass supplying the ``base_dir`` and ``run`` members used by the tests."""

            @property
            def base_dir(self):
                return 'test_dir'

            def run(self, context):
                pass

        # The output-directory check is patched out while constructing the fixture;
        # presumably so the placeholder paths above need not exist on disk.
        with patch('msprof_analyze.prof_common.path_manager.PathManager.check_output_directory_path'):
            self.analysis = ConcreteRecipeAnalysis(self.params)

    def test_enter_exit(self):
        """Check the context-manager protocol on the clean path and on the failing path."""
        # Clean path: entering the context yields the analysis object itself.
        with self.analysis as entered:
            self.assertEqual(entered, self.analysis)

        # Failing path: an exception raised inside the context must be logged
        # and its traceback printed.
        logger_error_target = 'msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.logger.error'
        with patch(logger_error_target) as mock_logger, patch('traceback.print_exc') as mock_traceback:
            try:
                with self.analysis:
                    raise ValueError('Test error')
            except ValueError:
                # Only the logging side effects are under test, so the error is
                # swallowed here if it propagates out of the context.
                pass
            mock_logger.assert_called_once_with('Failed to exit analysis: Test error')
            mock_traceback.assert_called_once()

    def test_output_path_property(self):
        """``output_path`` joins the configured output root, the cluster output dir and the recipe name."""
        expected_path = os.path.join('/tmp/to/output', Constant.CLUSTER_ANALYSIS_OUTPUT, 'test_recipe')
        self.assertEqual(self.analysis.output_path, expected_path)

    def test_filter_data(self):
        """``_filter_data`` drops entries whose payload is an empty list or ``None``."""
        mapper_results = [(1, [1, 2, 3]), (2, []), (3, None), (4, [4, 5])]
        kept = BaseRecipeAnalysis._filter_data(mapper_results)
        self.assertEqual(kept, [(1, [1, 2, 3]), (4, [4, 5])])

    @patch.object(DBManager, 'create_connect_db')
    @patch.object(DBManager, 'destroy_db_connect')
    def test_dump_data_to_db(self, mock_destroy, mock_create):
        """With a table name, ``dump_data`` opens the DB under ``output_path`` and closes it again."""
        conn, cursor = MagicMock(), MagicMock()
        mock_create.return_value = (conn, cursor)
        frame = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})

        self.analysis.dump_data(frame, 'test.db', 'test_table')

        mock_create.assert_called_once_with(os.path.join(self.analysis.output_path, 'test.db'))
        mock_destroy.assert_called_once_with(conn, cursor)

    @patch.object(FileManager, 'create_csv_from_dataframe')
    def test_dump_data_to_csv(self, mock_create_csv):
        """Without a table name, ``dump_data`` hands the frame to the CSV writer."""
        frame = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
        # NOTE(review): this patches ``convert_unit`` in ``common_func.utils``; if the
        # module under test imported the function by name, the patch is not
        # intercepted there - confirm the intended target.
        with patch('msprof_analyze.cluster_analyse.common_func.utils.convert_unit', return_value=frame):
            self.analysis.dump_data(frame, 'test.csv')

        # Without this check the test could only fail if dump_data raised.
        mock_create_csv.assert_called_once()

    @patch('shutil.copy')
    @patch('os.chmod')
    def test_create_notebook_without_replace(self, mock_chmod, mock_copy):
        """``create_notebook`` copies the recipe template into ``output_path`` and sets its mode."""
        notebook = 'test.ipynb'

        self.analysis.create_notebook(notebook)

        template = os.path.realpath(
            os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "msprof_analyze",
                         "cluster_analyse", "recipes", 'test_dir', notebook)
        )
        destination = os.path.join(self.analysis.output_path, notebook)
        mock_copy.assert_called_once_with(template, destination)
        mock_chmod.assert_called_once_with(destination, Constant.FILE_AUTHORITY)

    @patch('shutil.copy')
    @patch('os.chmod')
    def test_add_helper_file(self, mock_chmod, mock_copy):
        """``add_helper_file`` copies the helper into ``output_path`` and sets its mode."""
        helper_file = 'test_helper.txt'
        # ``os.path.dirname`` is stubbed so the expected source directory is a known value.
        mock_dirname = MagicMock(return_value='test_dir')
        with patch('os.path.dirname', mock_dirname):
            self.analysis.add_helper_file(helper_file)

        destination = os.path.join(self.analysis.output_path, helper_file)
        mock_copy.assert_called_once_with(os.path.join('test_dir', helper_file), destination)
        mock_chmod.assert_called_once_with(destination, Constant.FILE_AUTHORITY)

    def test_map_rank_pp_stage(self):
        """``map_rank_pp_stage`` maps each rank to its stage index for several size combinations."""
        tp_key = self.analysis.TP_SIZE
        pp_key = self.analysis.PP_SIZE
        dp_key = self.analysis.DP_SIZE
        # (distributed_args, expected rank -> stage mapping)
        cases = [
            ({}, {0: 0}),
            ({tp_key: 2}, {0: 0, 1: 0}),
            ({pp_key: 2}, {0: 0, 1: 1}),
            (
                {tp_key: 2, pp_key: 2, dp_key: 2},
                {
                    0: 0, 1: 0, 2: 0, 3: 0,
                    4: 1, 5: 1, 6: 1, 7: 1
                },
            ),
        ]
        for distributed_args, expected in cases:
            # subTest keeps the remaining cases running (and labelled) if one fails.
            with self.subTest(distributed_args=distributed_args):
                self.assertEqual(self.analysis.map_rank_pp_stage(distributed_args), expected)

    @patch('os.path.exists')
    @patch('json.loads')
    def test_load_distributed_args_from_extra_args(self, mock_json_loads, mock_exists):
        """Sizes supplied through ``_extra_args`` are returned under the TP/PP/DP keys."""
        self.analysis._extra_args = {'tp': 2, 'pp': 2, 'dp': 2}

        loaded = self.analysis.load_distributed_args()

        self.assertEqual(loaded, {
            self.analysis.TP_SIZE: 2,
            self.analysis.PP_SIZE: 2,
            self.analysis.DP_SIZE: 2
        })

    @patch('os.path.exists')
    @patch('json.loads')
    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DatabaseService')
    def test_load_distributed_args_from_db(self, mock_service, mock_json_loads, mock_exists):
        """With a stubbed ``DatabaseService``, ``load_distributed_args`` yields sizes of 1."""
        mock_exists.return_value = True
        stored_args = {
            self.analysis.TP_SIZE: 1,
            self.analysis.PP_SIZE: 1,
            self.analysis.DP_SIZE: 1
        }
        mock_df = MagicMock()
        # NOTE(review): ``loc.return_value`` configures *calling* ``loc(...)``; subscript
        # access (``loc[...]``) goes through ``__getitem__`` instead, and ``json.loads``
        # is patched with an unconfigured mock - confirm this stub is actually consumed.
        mock_df.loc.return_value = MagicMock(empty=False, values=[json.dumps(stored_args)])
        mock_service.return_value.query_data.return_value = {'META_DATA': mock_df}

        loaded = self.analysis.load_distributed_args()

        self.assertEqual(loaded, {
            self.analysis.TP_SIZE: 1,
            self.analysis.PP_SIZE: 1,
            self.analysis.DP_SIZE: 1
        })

    @patch('os.path.exists')
    def test_get_rank_db(self, mock_exists):
        """``_get_rank_db`` returns one entry per rank, populated from the path helpers."""
        mock_exists.return_value = True
        self.analysis._get_step_range = MagicMock(return_value={'id': 1})
        self.analysis._get_profiler_db_path = MagicMock(return_value='test_profiler.db')
        self.analysis._get_analysis_db_path = MagicMock(return_value='test_analysis.db')

        rank_entries = self.analysis._get_rank_db()

        self.assertEqual(len(rank_entries), 2)
        first_entry = rank_entries[0]
        self.assertEqual(first_entry[Constant.RANK_ID], 0)
        self.assertEqual(first_entry[Constant.PROFILER_DB_PATH], 'test_profiler.db')
        self.assertEqual(first_entry[Constant.ANALYSIS_DB_PATH], 'test_analysis.db')
        self.assertEqual(first_entry[Constant.STEP_RANGE], {'id': 1})

    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.logger.warning')
    @patch('os.path.exists')
    def test_get_rank_db_filters_by_required_db_keys_and_logs_summary(self, mock_exists, mock_warning):
        """Ranks lacking a required DB file are skipped and a summary warning is logged."""

        class RequiredDbRecipeAnalysis(BaseRecipeAnalysis):
            """Subclass that declares both the profiler DB and the analysis DB as required."""

            @property
            def base_dir(self):
                return 'test_dir'

            @property
            def required_db_keys(self):
                return [Constant.PROFILER_DB_PATH, Constant.ANALYSIS_DB_PATH]

            def run(self, context):
                pass

        # Rank 9 is requested but has no entry in DATA_MAP.
        params = dict(self.params)
        params[Constant.RANK_LIST] = '0,1,9'
        with patch('msprof_analyze.prof_common.path_manager.PathManager.check_output_directory_path'):
            analysis = RequiredDbRecipeAnalysis(params)

        analysis._get_step_range = MagicMock(return_value={'id': 1})
        analysis._get_profiler_db_path = MagicMock(side_effect=lambda rank_id, _: f'profiler_{rank_id}.db')
        analysis._get_analysis_db_path = MagicMock(
            side_effect=lambda rank_path: f'analysis_{os.path.basename(rank_path)}.db'
        )
        # Rank 0 has both files; rank 1 has only its profiler DB.
        mock_exists.side_effect = lambda path: path in {'profiler_0.db', 'analysis_0.db', 'profiler_1.db'}

        rank_entries = analysis._get_rank_db()

        self.assertEqual(rank_entries, [{
            Constant.RANK_ID: 0,
            Constant.PROFILER_DB_PATH: 'profiler_0.db',
            Constant.ANALYSIS_DB_PATH: 'analysis_0.db',
            Constant.STEP_RANGE: {'id': 1},
            Constant.PROFILING_PATH: '/tmp/to/data/0'
        }])
        mock_warning.assert_has_calls([
            call('Invalid Rank id: [9].'),
            call('test_recipe: missing analysis DB file (analysis.db) for 1 rank(s) [1]; these ranks will be skipped.')
        ])

    @patch.object(BaseRecipeAnalysis, '_get_rank_db')
    def test_mapper_func_returns_empty_when_no_valid_rank_db(self, mock_get_rank_db):
        """With no rank DB entries, ``mapper_func`` returns ``[]`` without touching the context."""
        mock_get_rank_db.return_value = []
        context = MagicMock()

        mapped = self.analysis.mapper_func(context)

        self.assertEqual(mapped, [])
        context.map.assert_not_called()
        context.wait.assert_not_called()

    def test_get_profiler_db_path(self):
        """The profiler DB file name changes once ``_prof_type`` is set to MindSpore."""
        pytorch_path = self.analysis._get_profiler_db_path(0, 'test_path')
        self.assertEqual(
            pytorch_path,
            os.path.join('test_path', Constant.SINGLE_OUTPUT, 'ascend_pytorch_profiler_0.db')
        )

        self.analysis._prof_type = Constant.MINDSPORE
        mindspore_path = self.analysis._get_profiler_db_path(0, 'test_path')
        self.assertEqual(
            mindspore_path,
            os.path.join('test_path', Constant.SINGLE_OUTPUT, 'ascend_mindspore_profiler_0.db')
        )

    def test_get_analysis_db_path(self):
        """The analysis DB file name changes once ``_prof_type`` is set to MindSpore."""
        pytorch_path = self.analysis._get_analysis_db_path('test_path')
        self.assertEqual(pytorch_path, os.path.join('test_path', Constant.SINGLE_OUTPUT, 'analysis.db'))

        self.analysis._prof_type = Constant.MINDSPORE
        mindspore_path = self.analysis._get_analysis_db_path('test_path')
        self.assertEqual(
            mindspore_path,
            os.path.join('test_path', Constant.SINGLE_OUTPUT, 'communication_analyzer.db')
        )

    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.create_connect_db')
    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.judge_table_exists')
    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.fetch_all_data')
    @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.destroy_db_connect')
    def test_get_step_range(self, mock_destroy, mock_fetch, mock_judge, mock_connect):
        """``_get_step_range`` returns the fetched row for the configured step id."""
        conn, cursor = MagicMock(), MagicMock()
        mock_connect.return_value = (conn, cursor)
        mock_judge.return_value = True
        mock_fetch.return_value = [{'id': 1, 'startNs': 0, 'endNs': 100}]
        self.analysis._step_id = 1

        step_range = self.analysis._get_step_range('test.db')

        self.assertEqual(step_range, {'id': 1, 'startNs': 0, 'endNs': 100})
if __name__ == '__main__':
    # Run the test cases in this module when the file is executed as a script.
    unittest.main()