import pytest
import torch
import torch_npu
import mindspeed.megatron_adaptor
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.core.tensor_parallel.data import broadcast_data
import megatron.core.parallel_state as Utils
from tests_extend.unit_tests.common import DistributedTest
class TestTPData(DistributedTest):
    """Distributed test for Megatron's tensor-parallel ``broadcast_data``."""

    world_size = 8
    # Parsed and registered once, when the class body executes (import time),
    # so Megatron's global args exist before any test method runs.
    args = parse_args(None, True)
    set_args(args)

    def test_broadcast_data(self):
        """Broadcast a dict of tensors and check keys 0 and 1 arrive unchanged.

        Every rank builds the identical ``input_data`` dict, so whatever
        ``broadcast_data`` distributes must equal the local copy.
        """
        # NOTE(review): positional (2, 4) presumably means tensor-parallel
        # size 2 and pipeline-parallel size 4 (2 * 4 == world_size) per
        # Megatron's initialize_model_parallel signature — confirm.
        Utils.initialize_model_parallel(2, 4)
        try:
            # Key i maps to an 8x8 tensor filled with float(i).
            input_data = {
                i: torch.ones((8, 8)).cuda() * float(i)
                for i in range(8)
            }
            dtype = torch.float32
            actual_output = broadcast_data([0, 1], input_data, dtype)
            assert torch.equal(actual_output[0], input_data[0])
            assert torch.equal(actual_output[1], input_data[1])
        finally:
            # Always tear down the parallel groups, even when an assertion
            # fails, so state does not leak into subsequent tests.
            Utils.destroy_model_parallel()