"""
Unit tests for the MindFormerBook class in mindformers/mindformer_book.py.
"""
from unittest.mock import patch
import pytest
from mindformers.mindformer_book import MindFormerBook
class TestMindFormerBook:
    """Tests for the trainer/pipeline support-list display methods of MindFormerBook."""

    def setup_method(self):
        """Remember the current support tables, then install small fixture tables.

        The previous values are kept on the test instance so that
        ``teardown_method`` can put them back after each test.
        """
        # Capture the originals first; a missing attribute is remembered as {}.
        self.original_trainer_list = getattr(MindFormerBook, '_TRAINER_SUPPORT_TASKS_LIST', {})
        self.original_pipeline_list = getattr(MindFormerBook, '_PIPELINE_SUPPORT_TASK_LIST', {})

        # Key order matters: the tests compare model names as ordered lists.
        trainer_fixture = {
            "general": {"some_key": "some_value"},
            "text_generation": {
                "common": {"config": "value"},
                "model1": "path1",
                "model2": "path2",
            },
            "text_classification": {
                "common": {"config": "value"},
                "model3": "path3",
            },
        }
        pipeline_fixture = {
            "text_generation": {
                "common": {"config": "value"},
                "model1": "path1",
                "model2": "path2",
            },
            "text_classification": {
                "common": {"config": "value"},
                "model3": "path3",
            },
            "image_classification": {
                "common": {"config": "value"},
                "model4": "path4",
            },
        }

        MindFormerBook._TRAINER_SUPPORT_TASKS_LIST = trainer_fixture
        MindFormerBook._PIPELINE_SUPPORT_TASK_LIST = pipeline_fixture

    def teardown_method(self):
        """Put back the support tables that ``setup_method`` saved."""
        MindFormerBook._TRAINER_SUPPORT_TASKS_LIST = self.original_trainer_list
        MindFormerBook._PIPELINE_SUPPORT_TASK_LIST = self.original_pipeline_list

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_trainer_support_model_list_without_task(self):
        """Trainer listing with no task: header is logged and a per-task dict is printed."""
        with patch('mindformers.mindformer_book.print_dict') as print_dict_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_trainer_support_model_list()

                logger_mock.info.assert_called_with("Trainer support model list of MindFormer is: ")
                print_dict_mock.assert_called_once()

                # First positional argument handed to print_dict.
                shown = print_dict_mock.call_args.args[0]
                for task_name in ("text_generation", "text_classification"):
                    assert task_name in shown
                # The "general" entry is expected to be left out of the listing.
                assert "general" not in shown
                assert shown["text_generation"] == ["model1", "model2"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_trainer_support_model_list_with_valid_task(self):
        """Trainer listing for ``text_generation``: only that task's model names are printed."""
        with patch('mindformers.mindformer_book.print_path_or_list') as print_list_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_trainer_support_model_list(task="text_generation")

                logger_mock.info.assert_called_with(
                    "Trainer support model list for %s task is: ", "text_generation"
                )
                print_list_mock.assert_called_once()
                shown = print_list_mock.call_args.args[0]
                assert shown == ["model1", "model2"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_trainer_support_model_list_with_another_valid_task(self):
        """Trainer listing for ``text_classification``: only that task's model names are printed."""
        with patch('mindformers.mindformer_book.print_path_or_list') as print_list_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_trainer_support_model_list(task="text_classification")

                logger_mock.info.assert_called_with(
                    "Trainer support model list for %s task is: ", "text_classification"
                )
                print_list_mock.assert_called_once()
                shown = print_list_mock.call_args.args[0]
                assert shown == ["model3"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_trainer_support_model_list_with_invalid_task(self):
        """Trainer listing for an unknown task: KeyError is raised and nothing is logged at info."""
        with patch('mindformers.mindformer_book.logger') as logger_mock:
            with pytest.raises(KeyError, match="unsupported task"):
                MindFormerBook.show_trainer_support_model_list(task="invalid_task")

            # Checked after the raises-context so the assertion is actually reached.
            logger_mock.info.assert_not_called()

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_pipeline_support_model_list_without_task(self):
        """Pipeline listing with no task: header is logged and a per-task dict is printed."""
        with patch('mindformers.mindformer_book.print_dict') as print_dict_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_pipeline_support_model_list()

                logger_mock.info.assert_called_with("Pipeline support model list of MindFormer is: ")
                print_dict_mock.assert_called_once()

                # First positional argument handed to print_dict.
                shown = print_dict_mock.call_args.args[0]
                for task_name in ("text_generation", "text_classification", "image_classification"):
                    assert task_name in shown
                assert shown["text_generation"] == ["model1", "model2"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_pipeline_support_model_list_with_valid_task(self):
        """Pipeline listing for ``text_generation``: only that task's model names are printed."""
        with patch('mindformers.mindformer_book.print_path_or_list') as print_list_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_pipeline_support_model_list(task="text_generation")

                logger_mock.info.assert_called_with(
                    "Pipeline support model list for %s task is: ", "text_generation"
                )
                print_list_mock.assert_called_once()
                shown = print_list_mock.call_args.args[0]
                assert shown == ["model1", "model2"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_pipeline_support_model_list_with_another_valid_task(self):
        """Pipeline listing for ``image_classification``: only that task's model names are printed."""
        with patch('mindformers.mindformer_book.print_path_or_list') as print_list_mock:
            with patch('mindformers.mindformer_book.logger') as logger_mock:
                MindFormerBook.show_pipeline_support_model_list(task="image_classification")

                logger_mock.info.assert_called_with(
                    "Pipeline support model list for %s task is: ", "image_classification"
                )
                print_list_mock.assert_called_once()
                shown = print_list_mock.call_args.args[0]
                assert shown == ["model4"]

    @pytest.mark.level1
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    def test_show_pipeline_support_model_list_with_invalid_task(self):
        """Pipeline listing for an unknown task: KeyError is raised and nothing is logged at info."""
        with patch('mindformers.mindformer_book.logger') as logger_mock:
            with pytest.raises(KeyError, match="unsupported task"):
                MindFormerBook.show_pipeline_support_model_list(task="invalid_task")

            # Checked after the raises-context so the assertion is actually reached.
            logger_mock.info.assert_not_called()