"""Unit tests for Trainer public interfaces: train() and get_batch()."""
import os
import sys
from unittest.mock import Mock, patch
import pytest
from mindformers.tools.register import MindFormerConfig
from mindformers.pynative.trainer import Trainer
# Put the repository root (four directory levels above this file) at the front
# of sys.path.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '../../../..')
)
sys.path.insert(0, project_root)

# NOTE(review): the mindformers imports at the top of this file execute before
# any of the stubs below are registered, so these stubs cannot influence those
# imports; they only affect lookups made through sys.modules afterwards.
# Confirm that this ordering is intended.

# --- mindspore stubs ---------------------------------------------------------
# One root Mock whose explicitly created children are also registered under
# their dotted module names, so attribute access and sys.modules lookups
# resolve to the same objects.
mindspore_mock = Mock(
    nn=Mock(),
    dataset=Mock(),
    common=Mock(
        tensor=Mock(),
        initializer=Mock(),
        parameter=Mock(),
        dtype=Mock(),
    ),
    ops=Mock(),
    context=Mock(),
)
# The communication stub reports rank 0 in a group of size 1.
communication_mock = Mock(
    management=Mock(
        get_rank=Mock(return_value=0),
        get_group_size=Mock(return_value=1),
    ),
)
sys.modules.update({
    'mindspore': mindspore_mock,
    'mindspore.nn': mindspore_mock.nn,
    'mindspore.dataset': mindspore_mock.dataset,
    'mindspore.common': mindspore_mock.common,
    'mindspore.common.tensor': mindspore_mock.common.tensor,
    'mindspore.common.initializer': mindspore_mock.common.initializer,
    'mindspore.common.parameter': mindspore_mock.common.parameter,
    'mindspore.common.dtype': mindspore_mock.common.dtype,
    'mindspore.ops': mindspore_mock.ops,
    'mindspore.context': mindspore_mock.context,
    'mindspore.parallel': Mock(),
    'mindspore.train': Mock(),
    'mindspore.communication': communication_mock,
    'mindspore.communication.management': communication_mock.management,
    'mindspore._checkparam': Mock(),
    'mindspore.amp': Mock(),
    'mindspore._c_expression': Mock(),
})

# --- mindformers sub-package stubs -------------------------------------------
# Each package-level stub gets an empty ``__all__`` list.
modules_mock = Mock(transformer=Mock())
modules_mock.__all__ = []

checkpoint_mock = Mock(checkpoint=Mock())
checkpoint_mock.__all__ = []

models_mock = Mock(llama=Mock())
models_mock.__all__ = []

dataset_mock = Mock()
dataset_mock.__all__ = []

core_mock = Mock(context=Mock())
core_mock.__all__ = []

pet_mock = Mock()
pet_mock.__all__ = []

wrapper_mock = Mock()
wrapper_mock.__all__ = []

generation_mock = Mock()
generation_mock.__all__ = []

pipeline_mock = Mock()
pipeline_mock.__all__ = []

trainer_mock = Mock(
    training_args=Mock(),
    optimizer_grouped_parameters=Mock(),
)
trainer_mock.__all__ = []

sys.modules.update({
    'mindformers.modules': modules_mock,
    'mindformers.modules.transformer': modules_mock.transformer,
    'mindformers.checkpoint': checkpoint_mock,
    'mindformers.checkpoint.checkpoint': checkpoint_mock.checkpoint,
    'mindformers.models': models_mock,
    'mindformers.models.llama': models_mock.llama,
    'mindformers.dataset': dataset_mock,
    'mindformers.run_check': Mock(),
    'mindformers.core': core_mock,
    'mindformers.core.context': core_mock.context,
    'mindformers.core.config_args': Mock(),
    'mindformers.core.lr': Mock(),
    'mindformers.core.optim': Mock(),
    'mindformers.core.callback': Mock(),
    'mindformers.core.callback_pynative': Mock(),
    'mindformers.core.metric': Mock(),
    'mindformers.pet': pet_mock,
    'mindformers.wrapper': wrapper_mock,
    'mindformers.generation': generation_mock,
    'mindformers.pipeline': pipeline_mock,
    'mindformers.trainer': trainer_mock,
    'mindformers.trainer.training_args': trainer_mock.training_args,
    'mindformers.trainer.optimizer_grouped_parameters': trainer_mock.optimizer_grouped_parameters,
    'mindformers.trainer.general_task_trainer': Mock(),
    'mindformers.model_runner': Mock(),
})

# --- top-level ``trainer`` stub ----------------------------------------------
# ``trainer.train_state`` is registered as its own stub module exposing a
# replaceable ``TrainerState`` attribute.
mock_trainer_state_class = Mock()
train_state_mock = Mock()
train_state_mock.TrainerState = mock_trainer_state_class
sys.modules.update({
    'trainer': Mock(),
    'trainer.train_state': train_state_mock,
})
class TestTrainerTrain:
    """Test cases for Trainer.train() interface.

    Every test creates the Trainer with ``Trainer.__new__`` so ``__init__`` is
    skipped, then attaches only the attributes and stubbed private methods that
    the scenario under test needs.
    """

    @pytest.fixture
    def mock_config(self):
        """Create a MindFormerConfig carrying step and batch-size settings."""
        config = MindFormerConfig()
        config.max_steps = 100
        config.eval_steps = 20
        config.save_steps = 50
        config.global_batch_size = 32
        return config

    @pytest.fixture
    def mock_trainer_state(self):
        """Create a mock TrainerState with counters mirroring ``mock_config``."""
        state = Mock()
        state.global_step = 0
        state.epoch_step = 10
        state.max_steps = 100
        state.eval_steps = 20
        state.save_steps = 50
        state.global_batch_size = 32
        state.update_epoch = Mock()
        return state

    @pytest.fixture
    def mock_model(self):
        """Create a mock model whose call returns ``{'loss': 0.5}``."""
        # Configure return_value instead of assigning ``__call__``: calling a
        # Mock dispatches through its type, so an instance-level ``__call__``
        # attribute would not change what ``model(...)`` returns.
        return Mock(return_value={'loss': 0.5})

    @pytest.fixture
    def mock_dataset(self):
        """Create a mock dataset of size 10 with a dict iterator of 200 batches."""
        dataset = Mock()
        dataset.__len__ = Mock(return_value=10)
        dataset.get_dataset_size = Mock(return_value=10)
        mock_iter = Mock()
        # Two alternating batches, repeated so the iterator holds 200 items.
        mock_iter.__next__ = Mock(side_effect=[
            {'input_ids': [1, 2, 3], 'labels': [2, 3, 4]},
            {'input_ids': [4, 5, 6], 'labels': [5, 6, 7]},
        ] * 100)
        dataset.create_dict_iterator = Mock(return_value=mock_iter)
        return dataset

    @pytest.fixture
    def mock_optimizer(self):
        """Create a mock optimizer."""
        optimizer = Mock()
        optimizer.step = Mock()
        return optimizer

    @pytest.fixture
    def mock_callback_handler(self):
        """Create a mock CallbackHandler exposing the train/epoch/step hooks."""
        handler = Mock()
        handler.on_train_begin = Mock()
        handler.on_train_end = Mock()
        handler.on_epoch_begin = Mock()
        handler.on_epoch_end = Mock()
        handler.on_step_begin = Mock()
        handler.on_step_end = Mock()
        return handler

    @staticmethod
    def _bare_trainer(config, model, dataset, optimizer, handler):
        """Build a Trainer without running ``__init__`` and attach shared parts.

        Only the attributes that every test in this class sets are attached
        here; scenario-specific attributes stay in the individual tests.

        Args:
            config: Object stored as ``trainer.config``.
            model: Object stored as ``trainer.model``.
            dataset: Object stored as ``trainer.train_dataset``.
            optimizer: Object stored as ``trainer.optimizer``.
            handler: Object stored as ``trainer.callback_handler``.

        Returns:
            The partially initialised Trainer, with ``_init_parallel_config``
            replaced by a Mock.
        """
        trainer = Trainer.__new__(Trainer)
        trainer.config = config
        trainer.model = model
        trainer.train_dataset = dataset
        trainer.optimizer = optimizer
        trainer.callback_handler = handler
        trainer._init_parallel_config = Mock()
        return trainer

    def test_train_pretrain_mode_success(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler, mock_trainer_state
    ):
        """Test train() in pretrain mode executes successfully."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )
        trainer.eval_dataset = None
        trainer.lr_scheduler = Mock()
        trainer.compute_metrics = None
        trainer.compute_loss_func = None
        trainer.processing_class = None
        trainer._load_checkpoint = Mock()
        trainer._get_dataset_size = Mock(return_value=10)
        trainer._inner_train_loop = Mock()

        # NOTE(review): this patches the stub registered in sys.modules under
        # the top-level name 'trainer'; confirm Trainer.train() resolves
        # TrainerState through that path.
        with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state):
            trainer.train(checkpoint_path=None, mode="pretrain", do_eval=False)

        trainer._init_parallel_config.assert_called_once()
        trainer._load_checkpoint.assert_not_called()
        mock_callback_handler.on_train_begin.assert_called_once()
        mock_callback_handler.on_train_end.assert_called_once()
        trainer._inner_train_loop.assert_called_once_with(False)

    def test_train_finetune_mode_requires_checkpoint(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler
    ):
        """Test train() in finetune mode raises error without checkpoint."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )
        trainer.eval_dataset = None
        trainer.lr_scheduler = Mock()
        trainer.compute_metrics = None
        trainer.compute_loss_func = None
        trainer.processing_class = None

        # _load_checkpoint is deliberately left unstubbed in this scenario.
        with pytest.raises(ValueError, match="checkpoint_path cannot be None"):
            trainer.train(checkpoint_path=None, mode="finetune", do_eval=False)

    def test_train_finetune_mode_with_checkpoint(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler, mock_trainer_state
    ):
        """Test train() in finetune mode with checkpoint loads correctly."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )
        trainer.eval_dataset = None
        trainer.lr_scheduler = Mock()
        trainer.compute_metrics = None
        trainer.compute_loss_func = None
        trainer.processing_class = None
        trainer._load_checkpoint = Mock()
        trainer._get_dataset_size = Mock(return_value=10)
        trainer._inner_train_loop = Mock()
        checkpoint_path = "/mock/checkpoint.ckpt"

        # os.path.exists is forced to True so the fake checkpoint path is
        # treated as present by any existence check.
        with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state), \
                patch('os.path.exists', return_value=True):
            trainer.train(checkpoint_path=checkpoint_path, mode="finetune", do_eval=False)

        trainer._load_checkpoint.assert_called_once_with(checkpoint_path, "finetune")

    def test_train_invalid_mode_raises_error(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler
    ):
        """Test train() raises error with invalid mode."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )

        with pytest.raises(ValueError, match="mode must be 'pretrain' or 'finetune'"):
            trainer.train(checkpoint_path=None, mode="invalid_mode", do_eval=False)

    def test_train_calls_callbacks_correctly(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler, mock_trainer_state
    ):
        """Test train() calls on_train_begin first and on_train_end last."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )
        trainer.lr_scheduler = Mock()
        trainer.compute_loss_func = None
        trainer._load_checkpoint = Mock()
        trainer._get_dataset_size = Mock(return_value=10)
        trainer._inner_train_loop = Mock()

        with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state):
            trainer.train(checkpoint_path=None, mode="pretrain", do_eval=False)

        assert mock_callback_handler.on_train_begin.called
        assert mock_callback_handler.on_train_end.called
        # Each recorded call unpacks as (name, args, kwargs); keep only the two
        # train-level hooks and compare their relative order.
        call_order = [
            recorded for recorded in mock_callback_handler.method_calls
            if recorded[0] in ['on_train_begin', 'on_train_end']
        ]
        assert call_order[0][0] == 'on_train_begin'
        assert call_order[-1][0] == 'on_train_end'

    def test_train_with_do_eval_true(
            self, mock_config, mock_model, mock_dataset,
            mock_optimizer, mock_callback_handler, mock_trainer_state
    ):
        """Test train() with do_eval=True passes flag to inner loop."""
        trainer = self._bare_trainer(
            mock_config, mock_model, mock_dataset, mock_optimizer, mock_callback_handler
        )
        trainer.lr_scheduler = Mock()
        trainer._get_dataset_size = Mock(return_value=10)
        trainer._inner_train_loop = Mock()

        with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state):
            trainer.train(checkpoint_path=None, mode="pretrain", do_eval=True)

        trainer._inner_train_loop.assert_called_once_with(True)
class TestTrainerGetBatch:
    """Test cases for Trainer.get_batch() interface.

    Trainers are created with ``Trainer.__new__`` (no ``__init__``) and given
    only a ``config`` attribute; the dataset iterator is a Mock whose
    ``__next__`` yields the payload under test.
    """

    @pytest.fixture
    def mock_config(self):
        """Create a mock config with both dataset-mode flags switched off."""
        return Mock(
            use_distribute_dataset=False,
            use_remove_redundant_dataset=False,
        )

    @pytest.fixture
    def mock_dataset_iter(self):
        """Create a mock dataset iterator that always yields one dict batch."""
        sample = {'input_ids': [1, 2, 3], 'labels': [2, 3, 4]}
        iterator = Mock()
        iterator.__next__ = Mock(return_value=sample)
        return iterator

    @staticmethod
    def _trainer_for(config):
        """Return a Trainer built without ``__init__`` that carries ``config``."""
        bare_trainer = Trainer.__new__(Trainer)
        bare_trainer.config = config
        return bare_trainer

    def test_get_batch_naive_mode_returns_dict(self, mock_config, mock_dataset_iter):
        """Test get_batch() in naive mode returns dict data."""
        result = self._trainer_for(mock_config).get_batch(mock_dataset_iter)

        assert isinstance(result, dict)
        assert 'input_ids' in result
        assert result['input_ids'] == [1, 2, 3]

    def test_get_batch_distributed_mode(self, mock_config, mock_dataset_iter):
        """Test get_batch() in distributed mode."""
        mock_config.use_distribute_dataset = True

        result = self._trainer_for(mock_config).get_batch(mock_dataset_iter)

        assert isinstance(result, dict)
        assert 'input_ids' in result

    def test_get_batch_remove_redundant_mode(self, mock_config, mock_dataset_iter):
        """Test get_batch() in remove redundant mode."""
        mock_config.use_remove_redundant_dataset = True

        result = self._trainer_for(mock_config).get_batch(mock_dataset_iter)

        assert isinstance(result, dict)
        assert 'input_ids' in result

    def test_get_batch_handles_tuple_data(self, mock_config):
        """Test get_batch() converts tuple data to dict."""
        bare_trainer = self._trainer_for(mock_config)
        tuple_source = Mock()
        tuple_source.__next__ = Mock(return_value=([1, 2, 3], [2, 3, 4]))

        result = bare_trainer.get_batch(tuple_source)

        assert isinstance(result, dict)
        assert 'input_ids' in result

    def test_get_batch_handles_list_data(self, mock_config):
        """Test get_batch() converts list data to dict."""
        bare_trainer = self._trainer_for(mock_config)
        list_source = Mock()
        list_source.__next__ = Mock(return_value=[[1, 2, 3], [2, 3, 4]])

        result = bare_trainer.get_batch(list_source)

        assert isinstance(result, dict)
        assert 'input_ids' in result

    def test_get_batch_handles_none_data(self, mock_config):
        """Test get_batch() handles None data gracefully."""
        bare_trainer = self._trainer_for(mock_config)
        empty_source = Mock()
        empty_source.__next__ = Mock(return_value=None)

        result = bare_trainer.get_batch(empty_source)

        assert isinstance(result, dict)
        assert len(result) == 0

    def test_get_batch_calls_correct_internal_method(self, mock_config, mock_dataset_iter):
        """Test get_batch() calls the correct internal method based on config."""
        bare_trainer = self._trainer_for(mock_config)
        # Replace all three strategy methods so each dispatch can be observed.
        bare_trainer._get_batch_naive = Mock(return_value={'input_ids': [1, 2, 3]})
        bare_trainer._get_batch_distributed = Mock()
        bare_trainer._get_batch_remove_redundant = Mock()

        # Both flags off.
        bare_trainer.get_batch(mock_dataset_iter)
        bare_trainer._get_batch_naive.assert_called_once()

        # Distributed flag on.
        bare_trainer.config.use_distribute_dataset = True
        bare_trainer._get_batch_distributed.reset_mock()
        bare_trainer.get_batch(mock_dataset_iter)
        bare_trainer._get_batch_distributed.assert_called_once()

        # Only the remove-redundant flag on.
        bare_trainer.config.use_distribute_dataset = False
        bare_trainer.config.use_remove_redundant_dataset = True
        bare_trainer._get_batch_remove_redundant.reset_mock()
        bare_trainer.get_batch(mock_dataset_iter)
        bare_trainer._get_batch_remove_redundant.assert_called_once()
if __name__ == '__main__':
    # Running this file directly hands it to pytest in verbose mode.
    cli_args = [__file__, '-v']
    pytest.main(cli_args)