"""Unit tests for VTP P2P communication functions."""
from contextlib import ExitStack
from unittest.mock import patch, MagicMock

# noqa: F401 below — unused by name here; presumably imported for its side effects,
# so keep it ahead of the mindspeed_mm import.
import mindspeed.megatron_adaptor  # noqa: F401
import torch
from mindspeed_mm.patchs.layerwise_disaggregated_training import p2p_communication_patch as mod
from tests.ut.utils import judge_expression
class TestVtpSendForward:
    """Tests for ``vtp_send_forward``: only the VTP stage's rank 0 issues the P2P send."""

    def _setup_mocks(self, rank0=True):
        """Build (not yet entered) patches for the send path.

        Args:
            rank0: value returned by the patched ``is_vtp_stage_rank0``.

        Returns:
            dict mapping a short name to an unentered ``patch.object`` context manager.
        """
        mocks = {
            "get_vtp_my_stage_idx": patch.object(mod, "get_vtp_my_stage_idx", return_value=1),
            "pp_group": patch.object(mod, "get_pipeline_model_parallel_group", return_value=MagicMock()),
            "next_rank": patch.object(mod, "get_pipeline_model_parallel_next_rank", return_value=5),
            "prev_rank": patch.object(mod, "get_pipeline_model_parallel_prev_rank", return_value=0),
            # The gate under test: whether this rank is rank 0 of its VTP stage.
            "is_rank0": patch.object(mod, "is_vtp_stage_rank0", return_value=rank0),
        }
        return mocks

    def _enter_mocks(self, stack, rank0):
        """Enter every patch from ``_setup_mocks(rank0)`` on *stack* and return the ``isend`` mock."""
        for patcher in self._setup_mocks(rank0=rank0).values():
            stack.enter_context(patcher)
        return stack.enter_context(
            patch.object(mod.torch.distributed, "isend", return_value=MagicMock())
        )

    def test_rank0_sends(self):
        """Stage rank 0 issues exactly one isend and returns a non-None handle."""
        with ExitStack() as stack:
            mock_isend = self._enter_mocks(stack, rank0=True)
            result = mod.vtp_send_forward(torch.tensor([1.0]))
            mock_isend.assert_called_once()
            judge_expression(result is not None)

    def test_non_rank0_returns_none(self):
        """A non-rank-0 stage member returns None and issues no isend."""
        with ExitStack() as stack:
            mock_isend = self._enter_mocks(stack, rank0=False)
            result = mod.vtp_send_forward(torch.tensor([1.0]))
            mock_isend.assert_not_called()
            judge_expression(result is None)
class TestVtpSendBackward:
    """Tests for ``vtp_send_backward``: only the VTP stage's rank 0 issues the P2P send."""

    def _setup_mocks(self, rank0=True):
        """Build (not yet entered) patches for the send path.

        Args:
            rank0: value returned by the patched ``is_vtp_stage_rank0``.

        Returns:
            dict mapping a short name to an unentered ``patch.object`` context manager.
        """
        mocks = {
            "get_vtp_my_stage_idx": patch.object(mod, "get_vtp_my_stage_idx", return_value=1),
            "pp_group": patch.object(mod, "get_pipeline_model_parallel_group", return_value=MagicMock()),
            "next_rank": patch.object(mod, "get_pipeline_model_parallel_next_rank", return_value=5),
            "prev_rank": patch.object(mod, "get_pipeline_model_parallel_prev_rank", return_value=2),
            # The gate under test: whether this rank is rank 0 of its VTP stage.
            "is_rank0": patch.object(mod, "is_vtp_stage_rank0", return_value=rank0),
        }
        return mocks

    def _enter_mocks(self, stack, rank0):
        """Enter every patch from ``_setup_mocks(rank0)`` on *stack* and return the ``isend`` mock."""
        for patcher in self._setup_mocks(rank0=rank0).values():
            stack.enter_context(patcher)
        return stack.enter_context(
            patch.object(mod.torch.distributed, "isend", return_value=MagicMock())
        )

    def test_rank0_sends(self):
        """Stage rank 0 issues exactly one isend and returns a non-None handle."""
        with ExitStack() as stack:
            mock_isend = self._enter_mocks(stack, rank0=True)
            result = mod.vtp_send_backward(torch.tensor([1.0]))
            mock_isend.assert_called_once()
            judge_expression(result is not None)

    def test_non_rank0_returns_none(self):
        """A non-rank-0 stage member returns None and issues no isend."""
        with ExitStack() as stack:
            mock_isend = self._enter_mocks(stack, rank0=False)
            result = mod.vtp_send_backward(torch.tensor([1.0]))
            mock_isend.assert_not_called()
            judge_expression(result is None)
class TestVTPRecvWork:
    """Tests for ``_VTPRecvWork.wait``.

    ``wait`` waits on the wrapped irecv handle when one is given. It broadcasts the
    tensor over ``intra_group`` unless ``dst_size`` is 1 or ``intra_group`` is None.
    """

    def test_wait_with_irecv_and_broadcast(self):
        """With an irecv handle and a multi-rank group: wait once, broadcast once."""
        irecv_work = MagicMock()
        tensor = torch.tensor([1.0])
        intra_group = MagicMock()
        with patch.object(mod.torch.distributed, "broadcast") as mock_bcast:
            work = mod._VTPRecvWork(irecv_work, tensor, broadcast_src=0, intra_group=intra_group, dst_size=4)
            work.wait()
            irecv_work.wait.assert_called_once()
            mock_bcast.assert_called_once_with(tensor, src=0, group=intra_group)

    def test_wait_no_irecv(self):
        """Without an irecv handle (non-rank-0 side) the broadcast still happens, with the same args."""
        tensor = torch.tensor([1.0])
        intra_group = MagicMock()
        with patch.object(mod.torch.distributed, "broadcast") as mock_bcast:
            work = mod._VTPRecvWork(None, tensor, broadcast_src=0, intra_group=intra_group, dst_size=2)
            work.wait()
            mock_bcast.assert_called_once_with(tensor, src=0, group=intra_group)

    def test_wait_dst_size_1_no_broadcast(self):
        """A single-rank destination stage needs no intra-stage broadcast."""
        with patch.object(mod.torch.distributed, "broadcast") as mock_bcast:
            work = mod._VTPRecvWork(MagicMock(), torch.tensor([1.0]), broadcast_src=0, intra_group=MagicMock(), dst_size=1)
            work.wait()
            mock_bcast.assert_not_called()

    def test_wait_no_intra_group_no_broadcast(self):
        """Without an intra-stage group there is nothing to broadcast over."""
        with patch.object(mod.torch.distributed, "broadcast") as mock_bcast:
            work = mod._VTPRecvWork(MagicMock(), torch.tensor([1.0]), broadcast_src=0, intra_group=None, dst_size=4)
            work.wait()
            mock_bcast.assert_not_called()
class TestVtpRecvForward:
    """Tests for ``vtp_recv_forward`` with this rank placed in stage 1 of a [[0], [1, 2, 3, 4]] layout."""

    def _setup_recv_mocks(self):
        """Return unentered patches for the recv path.

        The patches pin the stage index and stage layout (VTP sizes [1, 4]). They also
        replace the process groups, neighbour ranks and current device with fixed mocks.
        """
        return {
            "get_vtp_my_stage_idx": patch.object(mod, "get_vtp_my_stage_idx", return_value=1),
            "get_vtp_stage_ranks": patch.object(mod, "get_vtp_stage_ranks", return_value=[[0], [1, 2, 3, 4]]),
            "get_vtp_size_list": patch.object(mod, "get_vtp_size_list", return_value=[1, 4]),
            "intra_group": patch.object(mod, "get_tensor_model_parallel_group", return_value=MagicMock()),
            "pp_group": patch.object(mod, "get_pipeline_model_parallel_group", return_value=MagicMock()),
            "prev_rank": patch.object(mod, "get_pipeline_model_parallel_prev_rank", return_value=0),
            "next_rank": patch.object(mod, "get_pipeline_model_parallel_next_rank", return_value=1),
            "current_device": patch.object(mod.torch.cuda, "current_device", return_value=0),
        }

    def _enter_recv_mocks(self, stack, rank0):
        """Enter every recv-path patch plus an ``is_vtp_stage_rank0`` patch returning *rank0*."""
        for patcher in self._setup_recv_mocks().values():
            stack.enter_context(patcher)
        stack.enter_context(patch.object(mod, "is_vtp_stage_rank0", return_value=rank0))

    @staticmethod
    def _make_config():
        """Return a mock config whose ``pipeline_dtype`` is float32."""
        config = MagicMock()
        config.pipeline_dtype = torch.float32
        return config

    def test_sync_rank0_recv_and_broadcast(self):
        """Stage rank 0 posts one irecv, then broadcasts once, and returns a tensor."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=True)
            mock_irecv = stack.enter_context(
                patch.object(mod.torch.distributed, "irecv", return_value=MagicMock())
            )
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            result = mod.vtp_recv_forward((2, 3), self._make_config(), async_op=False)
            mock_irecv.assert_called_once()
            mock_bcast.assert_called_once()
            judge_expression(isinstance(result, torch.Tensor))

    def test_sync_non_rank0_broadcast_only(self):
        """Other stage members skip the irecv and only take part in the broadcast."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=False)
            mock_irecv = stack.enter_context(patch.object(mod.torch.distributed, "irecv"))
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            mod.vtp_recv_forward((2, 3), self._make_config(), async_op=False)
            mock_irecv.assert_not_called()
            mock_bcast.assert_called_once()

    def test_async_returns_tensor_and_reqs(self):
        """Async mode returns (tensor, reqs) with a "recv_prev" request and no eager broadcast."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=True)
            mock_irecv = stack.enter_context(
                patch.object(mod.torch.distributed, "irecv", return_value=MagicMock())
            )
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            tensor, reqs = mod.vtp_recv_forward((2, 3), self._make_config(), async_op=True)
            judge_expression(isinstance(tensor, torch.Tensor))
            judge_expression("recv_prev" in reqs)
            mock_irecv.assert_called_once()
            # The broadcast is presumably deferred to the returned request's wait()
            # (cf. _VTPRecvWork); it must not run at call time.
            mock_bcast.assert_not_called()
class TestVtpRecvBackward:
    """Tests for ``vtp_recv_backward`` with this rank placed in stage 1 of a [[0], [1, 2, 3, 4]] layout."""

    def _setup_recv_mocks(self):
        """Return unentered patches for the recv path.

        The patches pin the stage index and stage layout (VTP sizes [1, 4]). They also
        replace the process groups, neighbour ranks and current device with fixed mocks.
        """
        return {
            "get_vtp_my_stage_idx": patch.object(mod, "get_vtp_my_stage_idx", return_value=1),
            "get_vtp_stage_ranks": patch.object(mod, "get_vtp_stage_ranks", return_value=[[0], [1, 2, 3, 4]]),
            "get_vtp_size_list": patch.object(mod, "get_vtp_size_list", return_value=[1, 4]),
            "intra_group": patch.object(mod, "get_tensor_model_parallel_group", return_value=MagicMock()),
            "pp_group": patch.object(mod, "get_pipeline_model_parallel_group", return_value=MagicMock()),
            "prev_rank": patch.object(mod, "get_pipeline_model_parallel_prev_rank", return_value=0),
            "next_rank": patch.object(mod, "get_pipeline_model_parallel_next_rank", return_value=5),
            "current_device": patch.object(mod.torch.cuda, "current_device", return_value=0),
        }

    def _enter_recv_mocks(self, stack, rank0):
        """Enter every recv-path patch plus an ``is_vtp_stage_rank0`` patch returning *rank0*."""
        for patcher in self._setup_recv_mocks().values():
            stack.enter_context(patcher)
        stack.enter_context(patch.object(mod, "is_vtp_stage_rank0", return_value=rank0))

    @staticmethod
    def _make_config():
        """Return a mock config whose ``pipeline_dtype`` is float32."""
        config = MagicMock()
        config.pipeline_dtype = torch.float32
        return config

    def test_sync_rank0(self):
        """Stage rank 0 posts one irecv, then broadcasts once, and returns a tensor."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=True)
            mock_irecv = stack.enter_context(
                patch.object(mod.torch.distributed, "irecv", return_value=MagicMock())
            )
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            result = mod.vtp_recv_backward((2, 3), self._make_config(), async_op=False)
            mock_irecv.assert_called_once()
            mock_bcast.assert_called_once()
            judge_expression(isinstance(result, torch.Tensor))

    def test_async_returns_recv_next(self):
        """Async mode returns (tensor, reqs) with a "recv_next" request and no eager broadcast."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=True)
            mock_irecv = stack.enter_context(
                patch.object(mod.torch.distributed, "irecv", return_value=MagicMock())
            )
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            tensor, reqs = mod.vtp_recv_backward((2, 3), self._make_config(), async_op=True)
            judge_expression(isinstance(tensor, torch.Tensor))
            judge_expression("recv_next" in reqs)
            mock_irecv.assert_called_once()
            # The broadcast is presumably deferred to the returned request's wait()
            # (cf. _VTPRecvWork); it must not run at call time.
            mock_bcast.assert_not_called()

    def test_sync_non_rank0(self):
        """Other stage members skip the irecv and only take part in the broadcast."""
        with ExitStack() as stack:
            self._enter_recv_mocks(stack, rank0=False)
            mock_irecv = stack.enter_context(patch.object(mod.torch.distributed, "irecv"))
            mock_bcast = stack.enter_context(patch.object(mod.torch.distributed, "broadcast"))
            mod.vtp_recv_backward((2, 3), self._make_config(), async_op=False)
            mock_irecv.assert_not_called()
            mock_bcast.assert_called_once()