import types
import torch
from mindspeed import megatron_adaptor
from tests_extend.commons import set_random_seed, initialize_model_parallel
from tests_extend.unit_tests.common import DistributedTest
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from mindspeed.core.memory.recompute.norm.adaptor import mindspeed_norm_recompute_forward
from mindspeed.model.transformer import set_attention_mask
class TestNormRecompute(DistributedTest):
    """Compare the norm-recompute layer forward against the stock forward.

    Two ``TransformerBlock`` instances are built from the same config and
    share identical weights.  Every layer of the second block has its
    ``forward`` rebound to ``mindspeed_norm_recompute_forward``.  The test
    passes when both blocks produce matching outputs and matching gradients
    with respect to the input hidden states.
    """

    # Read by the DistributedTest harness (presumably the number of ranks to
    # launch -- confirm against tests_extend.unit_tests.common).
    world_size = 8
    # Executed once, at class-definition time: parse default arguments and
    # register them through set_args before any test body runs.
    args = parse_args(None, True)
    set_args(args)

    def test_norm_recompute(self):
        """Register norm-recompute arguments, then run the comparison."""
        args = parse_args(None, True)
        args.recompute_norm = True
        args.num_layers = 4
        args.recompute_norm_num_layers = 2
        args.pipeline_model_parallel_size = 2
        args.pipeline_dtype = torch.float32
        args.num_query_groups = None
        set_args(args)
        self.norm_recompute()

    def norm_recompute(self):
        """Build reference and patched blocks and assert they agree.

        Raises:
            AssertionError: if the forward outputs or the input gradients of
                the two blocks differ beyond ``torch.allclose`` defaults.
        """
        # NOTE(review): the two positional arguments are presumably the
        # tensor- and pipeline-parallel sizes -- confirm against
        # tests_extend.commons.initialize_model_parallel.
        initialize_model_parallel(2, 2)
        model_parallel_cuda_manual_seed(312)

        config = TransformerConfig(
            num_layers=4,
            hidden_size=12,
            num_attention_heads=4,
            use_cpu_initialization=True,
        )
        # Dropout is disabled so that both blocks are deterministic and can
        # be compared value-for-value.
        config.hidden_dropout = 0
        config.attention_dropout = 0
        config.gradient_accumulation_fusion = False

        transformer_block_ref = TransformerBlock(
            config, get_gpt_layer_local_spec(), post_layer_norm=True
        )
        transformer_block_test = TransformerBlock(
            config, get_gpt_layer_local_spec(), post_layer_norm=True
        )
        # load_state_dict copies the values into the target's own parameters,
        # so no intermediate copy of the state dict is required.
        transformer_block_test.load_state_dict(transformer_block_ref.state_dict())

        # Only the test block is patched; the reference keeps its stock forward.
        for layer in transformer_block_test.layers:
            layer.forward = types.MethodType(mindspeed_norm_recompute_forward, layer)

        sequence_length = 32
        micro_batch_size = 2

        transformer_block_ref.cuda()
        transformer_block_test.cuda()

        hidden_states_ref = torch.rand(
            (sequence_length, micro_batch_size, config.hidden_size)
        ).cuda()
        hidden_states_ref.requires_grad = True
        # A detached clone gives the test block its own leaf tensor, so each
        # block accumulates its input gradient independently.
        hidden_states_test = hidden_states_ref.clone().detach()
        hidden_states_test.requires_grad = True

        attention_mask = torch.zeros(
            (1, 1, sequence_length, sequence_length), dtype=torch.bool
        ).cuda()
        set_attention_mask(attention_mask)

        out_ref = transformer_block_ref(
            hidden_states=hidden_states_ref, attention_mask=attention_mask
        )
        out_test = transformer_block_test(
            hidden_states=hidden_states_test, attention_mask=attention_mask
        )
        assert torch.allclose(out_ref, out_test), (
            "forward output of the norm-recompute block differs from the reference"
        )

        # Each backward pass is seeded from its own output tensor.
        out_ref.backward(torch.ones_like(out_ref))
        out_test.backward(torch.ones_like(out_test))
        assert torch.allclose(hidden_states_ref.grad, hidden_states_test.grad), (
            "input gradient of the norm-recompute block differs from the reference"
        )

    # Backward-compatible alias for the original, misspelled helper name.
    norm_recopute = norm_recompute