import unittest
from unittest import mock
from mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel import (
finish_grad_sync,
finish_grad_sync_ldt,
register_grad_ready,
)
class TestDistributedDataParallel(unittest.TestCase):
@mock.patch('mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.finish_grad_sync_ldt')
def test_finish_grad_sync_empty_buckets(self, mock_finish_grad_sync_ldt):
mock_self = mock.MagicMock()
mock_self.bucket_groups = []
mock_self.expert_parallel_bucket_groups = []
finish_grad_sync(mock_self)
mock_finish_grad_sync_ldt.assert_not_called()
@mock.patch('mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.finish_grad_sync_ldt')
def test_finish_grad_sync_with_buckets(self, mock_finish_grad_sync_ldt):
mock_self = mock.MagicMock()
bucket_group1 = mock.MagicMock()
bucket_group2 = mock.MagicMock()
expert_bucket_group = mock.MagicMock()
mock_self.bucket_groups = [bucket_group1, bucket_group2]
mock_self.expert_parallel_bucket_groups = [expert_bucket_group]
finish_grad_sync(mock_self)
self.assertEqual(mock_finish_grad_sync_ldt.call_count, 3)
mock_finish_grad_sync_ldt.assert_any_call(bucket_group1)
mock_finish_grad_sync_ldt.assert_any_call(bucket_group2)
mock_finish_grad_sync_ldt.assert_any_call(expert_bucket_group)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_finish_grad_sync_ldt_pipeline_first_stage(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = True
bucket_group = mock.MagicMock()
finish_grad_sync_ldt(bucket_group)
bucket_group.start_grad_sync.assert_not_called()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_finish_grad_sync_ldt_non_overlap_grad_reduce(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
bucket_group = mock.MagicMock()
bucket_group.ddp_config.overlap_grad_reduce = False
finish_grad_sync_ldt(bucket_group)
self.assertFalse(bucket_group.param_gather_dispatched)
bucket_group.start_grad_sync.assert_called_once()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.torch.cuda.default_stream'
)
def test_finish_grad_sync_ldt_multi_dist_optimizer(self, mock_default_stream, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
bucket_group = mock.MagicMock()
bucket_group.ddp_config.overlap_grad_reduce = True
bucket_group.ddp_config.num_distributed_optimizer_instances = 2
finish_grad_sync_ldt(bucket_group)
mock_default_stream.return_value.wait_stream.assert_called_once_with(bucket_group.communication_stream)
bucket_group.start_grad_sync.assert_not_called()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_finish_grad_sync_ldt_handle_wait(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
bucket_group = mock.MagicMock()
bucket_group.ddp_config.overlap_grad_reduce = True
bucket_group.ddp_config.num_distributed_optimizer_instances = 1
mock_handle = mock.MagicMock()
bucket_group.grad_reduce_handle = [mock_handle]
finish_grad_sync_ldt(bucket_group)
mock_handle.wait.assert_called_once()
self.assertIsNone(bucket_group.grad_reduce_handle)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_finish_grad_sync_ldt_no_handle(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
bucket_group = mock.MagicMock()
bucket_group.ddp_config.overlap_grad_reduce = True
bucket_group.ddp_config.num_distributed_optimizer_instances = 1
bucket_group.grad_reduce_handle = None
bucket_group.params_with_grad = []
bucket_group.params = []
with self.assertRaises(AssertionError):
finish_grad_sync_ldt(bucket_group)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_pipeline_first_stage(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = True
mock_self = mock.MagicMock()
param = mock.MagicMock()
register_grad_ready(mock_self, param)
mock_self.params_with_grad.add.assert_not_called()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_not_overlap_grad_reduce(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = False
param = mock.MagicMock()
with self.assertRaises(AssertionError):
register_grad_ready(mock_self, param)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_not_last_microbatch(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = True
mock_self.is_last_microbatch = False
param = mock.MagicMock()
register_grad_ready(mock_self, param)
mock_self.params_with_grad.add.assert_not_called()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_param_not_in_bucket(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = True
mock_self.is_last_microbatch = True
mock_self.param_to_bucket = {}
param = mock.MagicMock()
with self.assertRaises(AssertionError):
register_grad_ready(mock_self, param)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_grad_twice(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = True
mock_self.is_last_microbatch = True
param = mock.MagicMock()
mock_self.param_to_bucket = {param: mock.MagicMock()}
mock_self.params_with_grad = {param}
with self.assertRaises(AssertionError):
register_grad_ready(mock_self, param)
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_not_all_params_ready(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = True
mock_self.is_last_microbatch = True
param = mock.MagicMock()
mock_self.param_to_bucket = {param: mock.MagicMock()}
mock_self.params_with_grad = mock.MagicMock()
mock_self.params_with_grad.__len__.return_value = 1
mock_self.params = [param, mock.MagicMock()]
register_grad_ready(mock_self, param)
mock_self.params_with_grad.add.assert_called_once_with(param)
mock_self.start_grad_sync.assert_not_called()
@mock.patch(
'mindspeed_llm.core.layerwise_disaggregated_training.distributed_data_parallel.parallel_state.is_pipeline_first_stage'
)
def test_register_grad_ready_all_params_ready(self, mock_is_pipeline_first_stage):
mock_is_pipeline_first_stage.return_value = False
mock_self = mock.MagicMock()
mock_self.ddp_config.overlap_grad_reduce = True
mock_self.is_last_microbatch = True
param = mock.MagicMock()
mock_self.param_to_bucket = {param: mock.MagicMock()}
mock_self.params_with_grad = mock.MagicMock()
mock_self.params_with_grad.__len__.return_value = 1
mock_self.params = [param]
register_grad_ready(mock_self, param)
mock_self.params_with_grad.add.assert_called_once_with(param)
mock_self.start_grad_sync.assert_called_once()
if __name__ == '__main__':
unittest.main()