import pytest
import torch
import torch_npu
import mindspeed.megatron_adaptor
import megatron.core.parallel_state as Utils
from tests_extend.unit_tests.common import DistributedTest
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.core.tensor_parallel.layers import VocabParallelEmbedding, RowParallelLinear, ColumnParallelLinear
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
class TestInitialization(DistributedTest):
    """Compare tensor-parallel layer weights built with two parallel sizes.

    Each test builds the same layer twice: once after
    ``initialize_model_parallel(1, 1)`` and once after
    ``initialize_model_parallel(4, 1)``. On global rank 0 it then asserts
    that the second weight is one quarter of the first along the split
    dimension and equal to the leading slice of the first.
    """

    # Consumed by DistributedTest (presumably the number of ranks to launch;
    # confirm against tests_extend.unit_tests.common).
    world_size = 8

    # NOTE: these two statements run once, when the class body is executed
    # (i.e. at import/collection time), not per test.
    args = parse_args(None, True)
    set_args(args)

    transformer_config = TransformerConfig(
        num_layers=1,
        hidden_size=12,
        num_attention_heads=4,
        use_cpu_initialization=True,
    )

    @staticmethod
    def _build_weight(parallel_size, device_seed, build_layer):
        """Build a layer under a fresh model-parallel setup and return its weight.

        Args:
            parallel_size: First positional argument forwarded to
                ``initialize_model_parallel``; the second is always ``1``.
            device_seed: Seed passed to ``model_parallel_cuda_manual_seed``.
            build_layer: Zero-argument callable returning the layer to inspect.

        Returns:
            The ``weight`` attribute of the layer produced by ``build_layer``.
        """
        Utils.initialize_model_parallel(parallel_size, 1)
        try:
            # The global torch seed is the same (42) for every build; only the
            # model-parallel device seed is varied by the callers.
            torch.manual_seed(42)
            model_parallel_cuda_manual_seed(device_seed)
            return build_layer().weight
        finally:
            # Always tear the parallel state down, even when construction
            # raises, so a failure here cannot leak initialised groups into
            # later code running in the same process.
            Utils.destroy_model_parallel()

    @pytest.mark.timeout(100)
    def test_embedding_init(self):
        """The 4-way embedding weight equals the first 4 rows of the 1-way one."""

        def make_embedding():
            return VocabParallelEmbedding(
                num_embeddings=16,
                embedding_dim=4,
                init_method=self.transformer_config.init_method,
                config=self.transformer_config,
            )

        # Device seeds differ on purpose between the two builds (42 vs 41),
        # exactly as in the original test.
        tp1 = self._build_weight(1, 42, make_embedding)
        tp4 = self._build_weight(4, 41, make_embedding)

        # Only global rank 0 performs the comparison.
        if torch.distributed.get_rank() == 0:
            assert tp4.shape[0] * 4 == tp1.shape[0]
            assert torch.allclose(tp1[:4], tp4)

    @pytest.mark.timeout(100)
    def test_row_init(self):
        """The 4-way row-parallel weight equals the first 4 columns of the 1-way one."""

        def make_row_linear():
            return RowParallelLinear(
                input_size=16,
                output_size=16,
                init_method=self.transformer_config.init_method,
                bias=True,
                input_is_parallel=False,
                config=self.transformer_config,
                skip_bias_add=False,
            )

        tp1 = self._build_weight(1, 42, make_row_linear)
        tp4 = self._build_weight(4, 41, make_row_linear)

        # Only global rank 0 performs the comparison; the split is along dim 1.
        if torch.distributed.get_rank() == 0:
            assert tp4.shape[1] * 4 == tp1.shape[1]
            assert torch.allclose(tp1[:, :4], tp4)

    @pytest.mark.timeout(100)
    def test_col_init(self):
        """The 4-way column-parallel weight equals the first 4 rows of the 1-way one."""

        def make_column_linear():
            return ColumnParallelLinear(
                input_size=16,
                output_size=16,
                init_method=self.transformer_config.init_method,
                bias=True,
                config=self.transformer_config,
                skip_bias_add=False,
            )

        tp1 = self._build_weight(1, 42, make_column_linear)
        tp4 = self._build_weight(4, 41, make_column_linear)

        # Only global rank 0 performs the comparison; the split is along dim 0.
        if torch.distributed.get_rank() == 0:
            assert tp4.shape[0] * 4 == tp1.shape[0]
            assert torch.allclose(tp1[:4], tp4)