"""Unit tests for VTP utility functions."""
import math
from unittest.mock import patch, MagicMock
import mindspeed.megatron_adaptor
import pytest
import torch
from mindspeed_mm.patchs.layerwise_disaggregated_training import utils_patch as mod
from tests.ut.utils import judge_expression
# Shared patcher that makes ``torch.cuda.current_device()`` return the string
# "cpu" while it is active; tests in this module enter it via ``with _PATCH_CUDA``.
# NOTE(review): presumably this lets code that allocates tensors on the
# "current device" run on a host without an accelerator - confirm against utils_patch.
_PATCH_CUDA = patch("torch.cuda.current_device", return_value="cpu")
# pytest marker that skips a test when ``torch.cuda.is_available()`` is False.
# NOTE(review): not applied to any test in the visible part of this module.
_REQUIRES_CUDA = pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA/NPU required"
)
class TestAllreduceModelParallel:
    """Checks which all-reduce ``_ldt_allreduce_model_parallel`` forwards to."""

    def test_vtp_enabled_calls_vtp_allreduce(self):
        """With VTP reported on and VDP off, ``vtp_allreduce`` gets the tensor and op."""
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        vtp_allreduce = patch.object(mod, "vtp_allreduce")
        with vtp_on, vdp_off, vtp_allreduce as allreduce_mock:
            payload = torch.tensor([1.0])
            mod._ldt_allreduce_model_parallel(
                payload, op=torch.distributed.ReduceOp.SUM, group=MagicMock()
            )
            # The group is not part of the expected call: only tensor and op are checked.
            allreduce_mock.assert_called_once_with(
                payload, op=torch.distributed.ReduceOp.SUM
            )

    def test_vtp_disabled_calls_standard_allreduce(self):
        """With VTP reported off, ``torch.distributed.all_reduce`` gets tensor, op and group."""
        vtp_off = patch.object(mod, "is_vtp_enabled", return_value=False)
        dist_allreduce = patch("torch.distributed.all_reduce")
        with vtp_off, dist_allreduce as allreduce_mock:
            payload = torch.tensor([1.0])
            process_group = MagicMock()
            mod._ldt_allreduce_model_parallel(
                payload, op=torch.distributed.ReduceOp.MAX, group=process_group
            )
            allreduce_mock.assert_called_once_with(
                payload, op=torch.distributed.ReduceOp.MAX, group=process_group
            )
class TestLdtReduceMaxStat:
    """Tests for ``ldt_reduce_max_stat_across_model_parallel_group``."""

    def test_none_input_returns_none(self):
        """Passing ``None`` yields ``None`` when the collectives are mocked out."""
        stub_allreduce = patch.object(mod, "_ldt_allreduce_model_parallel")
        stub_tp_group = patch.object(
            mod.mpu, "get_tensor_model_parallel_group", return_value=MagicMock()
        )
        with _PATCH_CUDA, stub_allreduce, stub_tp_group:
            outcome = mod.ldt_reduce_max_stat_across_model_parallel_group(None)
            judge_expression(outcome is None)

    def test_valid_stat_returns_value(self):
        """The returned stat is whatever the (mocked) all-reduce leaves in the tensor."""

        def overwrite_with_five(tensor, op, group=None):
            # Stand-in for the reduction: write 5.0 in place so the test can
            # tell the reduced value apart from the 3.0 that was passed in.
            tensor.fill_(5.0)

        fake_allreduce = patch.object(
            mod, "_ldt_allreduce_model_parallel", side_effect=overwrite_with_five
        )
        stub_tp_group = patch.object(
            mod.mpu, "get_tensor_model_parallel_group", return_value=MagicMock()
        )
        with _PATCH_CUDA, fake_allreduce, stub_tp_group:
            outcome = mod.ldt_reduce_max_stat_across_model_parallel_group(3.0)
            judge_expression(outcome == 5.0)
class TestLdtLogicalAnd:
    """Tests for ``ldt_logical_and_across_model_parallel_group`` with mocked collectives."""

    @staticmethod
    def _reduce_flag(flag):
        """Call the function under test for ``flag`` and return its result.

        The all-reduce helper and the tensor-model-parallel group lookup are
        replaced by mocks, and ``_PATCH_CUDA`` is active for the call.
        """
        stub_allreduce = patch.object(mod, "_ldt_allreduce_model_parallel")
        stub_tp_group = patch.object(
            mod.mpu, "get_tensor_model_parallel_group", return_value=MagicMock()
        )
        with _PATCH_CUDA, stub_allreduce, stub_tp_group:
            return mod.ldt_logical_and_across_model_parallel_group(flag)

    def test_true_input(self):
        """``True`` in gives exactly ``True`` out when the reduction is a no-op mock."""
        outcome = self._reduce_flag(True)
        judge_expression(outcome is True)

    def test_false_input(self):
        """``False`` in gives exactly ``False`` out when the reduction is a no-op mock."""
        outcome = self._reduce_flag(False)
        judge_expression(outcome is False)
class TestLdtGetGradNormFp32:
    """Tests for ``ldt_get_grad_norm_fp32`` with every collective replaced by a mock."""

    def test_single_tensor_wrapped_to_list(self):
        """A bare tensor (not a list) is accepted and a ``float`` comes back."""
        vtp_off = patch.object(mod, "is_vtp_enabled", return_value=False)
        dist_allreduce = patch("torch.distributed.all_reduce")
        no_dp_group = patch.object(mod, "get_data_parallel_group_if_dtensor", return_value=None)
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_off, dist_allreduce, no_dp_group, local_passthrough:
            norm = mod.ldt_get_grad_norm_fp32(torch.tensor([3.0, 4.0]), norm_type=float('inf'))
            judge_expression(isinstance(norm, float))

    def test_norm_type_inf(self):
        """The infinity norm of ``[3, -4]`` is 4 (largest absolute value)."""
        vtp_off = patch.object(mod, "is_vtp_enabled", return_value=False)
        dist_allreduce = patch("torch.distributed.all_reduce")
        no_dp_group = patch.object(mod, "get_data_parallel_group_if_dtensor", return_value=None)
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_off, dist_allreduce, no_dp_group, local_passthrough:
            grads = [torch.tensor([3.0, -4.0])]
            norm = mod.ldt_get_grad_norm_fp32(grads, norm_type=float('inf'))
            judge_expression(math.isclose(norm, 4.0))

    def test_norm_type_2(self):
        """For ``norm_type=2`` a float is returned and the last model-parallel reduce uses SUM."""
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        model_parallel_allreduce = patch.object(mod, "_ldt_allreduce_model_parallel")
        dist_allreduce = patch("torch.distributed.all_reduce")
        no_dp_group = patch.object(mod, "get_data_parallel_group_if_dtensor", return_value=None)
        vdp_size_one = patch.object(mod, "get_vdp_size", return_value=1)
        first_stage = patch.object(mod.mpu, "is_pipeline_first_stage", return_value=True)
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_on, vdp_off, model_parallel_allreduce as allreduce_mock:
            with dist_allreduce, no_dp_group, vdp_size_one, first_stage, local_passthrough:
                grads = [torch.tensor([3.0, 4.0])]
                norm = mod.ldt_get_grad_norm_fp32(grads, norm_type=2)
                judge_expression(isinstance(norm, float))
                # call_args[1] is the kwargs dict of the most recent call.
                used_op = allreduce_mock.call_args[1]['op']
                judge_expression(used_op == torch.distributed.ReduceOp.SUM)

    def test_norm_type_other(self):
        """A norm type other than 2 or infinity (here 3) still returns a ``float``."""
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        model_parallel_allreduce = patch.object(mod, "_ldt_allreduce_model_parallel")
        dist_allreduce = patch("torch.distributed.all_reduce")
        no_dp_group = patch.object(mod, "get_data_parallel_group_if_dtensor", return_value=None)
        vdp_size_one = patch.object(mod, "get_vdp_size", return_value=1)
        first_stage = patch.object(mod.mpu, "is_pipeline_first_stage", return_value=True)
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_on, vdp_off, model_parallel_allreduce:
            with dist_allreduce, no_dp_group, vdp_size_one, first_stage, local_passthrough:
                grads = [torch.tensor([3.0, 4.0])]
                norm = mod.ldt_get_grad_norm_fp32(grads, norm_type=3)
                judge_expression(isinstance(norm, float))

    def test_with_data_parallel_group(self):
        """When a data-parallel group is reported, ``torch.distributed.all_reduce`` is called."""
        vtp_off = patch.object(mod, "is_vtp_enabled", return_value=False)
        dist_allreduce = patch("torch.distributed.all_reduce")
        some_dp_group = patch.object(
            mod, "get_data_parallel_group_if_dtensor", return_value=MagicMock()
        )
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_off, dist_allreduce as allreduce_mock, some_dp_group:
            with local_passthrough:
                grads = [torch.tensor([3.0, -4.0])]
                mod.ldt_get_grad_norm_fp32(grads, norm_type=float('inf'))
                allreduce_mock.assert_called()
                judge_expression(allreduce_mock.call_count >= 1)

    def test_empty_grads_norm_type_2(self):
        """An empty gradient list with ``norm_type=2`` still returns a ``float``."""
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        model_parallel_allreduce = patch.object(mod, "_ldt_allreduce_model_parallel")
        dist_allreduce = patch("torch.distributed.all_reduce")
        no_dp_group = patch.object(mod, "get_data_parallel_group_if_dtensor", return_value=None)
        vdp_size_one = patch.object(mod, "get_vdp_size", return_value=1)
        first_stage = patch.object(mod.mpu, "is_pipeline_first_stage", return_value=True)
        local_passthrough = patch.object(mod, "to_local_if_dtensor", side_effect=lambda x: x)
        with _PATCH_CUDA, vtp_on, vdp_off, model_parallel_allreduce:
            with dist_allreduce, no_dp_group, vdp_size_one, first_stage, local_passthrough:
                norm = mod.ldt_get_grad_norm_fp32([], norm_type=2)
                judge_expression(isinstance(norm, float))
class TestLdtVdpBarrierWrapper:
    """Tests for the barrier function produced by ``ldt_vdp_barrier_wrapper``."""

    def test_vtp_no_group_uses_hierarchical(self):
        """VTP on and ``group=None``: ``vtp_hierarchical_barrier`` runs once, result is ``None``."""
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        hierarchical = patch.object(mod, "vtp_hierarchical_barrier")
        with vtp_on, vdp_off, hierarchical as barrier_mock:
            wrapped = mod.ldt_vdp_barrier_wrapper(MagicMock())
            outcome = wrapped(group=None)
            barrier_mock.assert_called_once()
            judge_expression(outcome is None)

    def test_vtp_with_group_uses_original(self):
        """VTP on with an explicit group: the wrapped barrier is called and its result returned."""
        wrapped_barrier = MagicMock(return_value="ok")
        vtp_on = patch.object(mod, "is_vtp_enabled", return_value=True)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        with vtp_on, vdp_off:
            wrapped = mod.ldt_vdp_barrier_wrapper(wrapped_barrier)
            outcome = wrapped(group=MagicMock())
            wrapped_barrier.assert_called_once()
            judge_expression(outcome == "ok")

    def test_non_vtp_uses_original(self):
        """VTP off: the wrapped barrier is called once with ``group=None`` unchanged."""
        wrapped_barrier = MagicMock(return_value="ok")
        vtp_off = patch.object(mod, "is_vtp_enabled", return_value=False)
        vdp_off = patch.object(mod, "is_vdp_enabled", return_value=False)
        with vtp_off, vdp_off:
            wrapped = mod.ldt_vdp_barrier_wrapper(wrapped_barrier)
            wrapped(group=None)
            wrapped_barrier.assert_called_once_with(group=None)
class TestVtpAllGatherWrapper:
    """Tests for the function produced by ``vtp_all_gather_into_tensor_wrapper``."""

    def test_vtp_no_group_copies(self):
        """VTP on and ``group=None``: result is ``None`` and the wrapped function is
        called once with ``(output, input, group=None, async_op=False)``."""
        wrapped_gather = MagicMock(return_value=None)
        with patch.object(mod, "is_vtp_enabled", return_value=True):
            gather = mod.vtp_all_gather_into_tensor_wrapper(wrapped_gather)
            destination = torch.zeros(4)
            source = torch.ones(4)
            outcome = gather(destination, source, group=None)
            judge_expression(outcome is None)
            wrapped_gather.assert_called_once_with(
                destination, source, group=None, async_op=False
            )

    def test_vtp_with_group_uses_original(self):
        """VTP on with an explicit group: the wrapped function is called exactly once."""
        wrapped_gather = MagicMock(return_value="ok")
        with patch.object(mod, "is_vtp_enabled", return_value=True):
            gather = mod.vtp_all_gather_into_tensor_wrapper(wrapped_gather)
            gather(torch.zeros(4), torch.ones(4), group=MagicMock())
            wrapped_gather.assert_called_once()

    def test_non_vtp_uses_original(self):
        """VTP off: the wrapped function is called exactly once, even with ``async_op=True``."""
        wrapped_gather = MagicMock(return_value="ok")
        with patch.object(mod, "is_vtp_enabled", return_value=False):
            gather = mod.vtp_all_gather_into_tensor_wrapper(wrapped_gather)
            gather(torch.zeros(4), torch.ones(4), group=None, async_op=True)
            wrapped_gather.assert_called_once()