import types
import torch
from mindspeed import megatron_adaptor
from mindspeed.model.transformer import set_attention_mask
from mindspeed.features_manager.recompute.activation_function import RecomputeActivationFeature
from mindspeed.patch_utils import MindSpeedPatchesManager as pm
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from tests_extend.commons import initialize_model_parallel
from tests_extend.unit_tests.common import DistributedTest
class ActRecomputeFeatureTset:
@staticmethod
def act_recompute_register_patch(patch_manager):
args = parse_args(None, True)
args.num_layers = 4
args.recompute_activation_function = True
args.recompute_activation_function_num_layers = 2
args.num_query_groups = None
set_args(args)
recompute_activation_feature_func = RecomputeActivationFeature()
recompute_activation_feature_func.validate_args(args)
recompute_activation_feature_func.register_patches(patch_manager, args)
patch_manager.apply_patches()
@staticmethod
def del_recompute_register_patch(patch_manager):
patch_manager.register_patch('megatron.core.transformer.mlp.MLP.forward', MLP.forward, force_patch=True)
patch_manager.remove_wrappers('megatron.core.transformer.transformer_layer.TransformerLayer.__init__', 'parallel_transformer_layer_init_wrapper')
patch_manager.apply_patches()
class TestActivationRecompute(DistributedTest):
world_size = 8
def test_activation_recompute(self):
ActRecomputeFeatureTset().act_recompute_register_patch(pm)
self.activation_recopute()
ActRecomputeFeatureTset().del_recompute_register_patch(pm)
def activation_recopute(self):
initialize_model_parallel(2, 2)
model_parallel_cuda_manual_seed(312)
submodules = MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
)
config = TransformerConfig(num_layers=4, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
config.hidden_dropout = 0
config.attention_dropout = 0
config.layer_number = 4
config.pipeline_model_parallel_size = 2
config.gradient_accumulation_fusion = False
transformer_block_ref = MLP(config, submodules=submodules)
transformer_block_test = MLP(config, submodules=submodules)
transformer_block_test.load_state_dict(transformer_block_ref.state_dict().copy())
sequence_length = 32
micro_batch_size = 2
transformer_block_ref.cuda()
transformer_block_test.cuda()
hidden_states_ref = torch.rand((sequence_length, micro_batch_size, config.hidden_size)).cuda()
hidden_states_ref.requires_grad = True
hidden_states_test = hidden_states_ref.clone().detach()
hidden_states_test.requires_grad = True
attention_mask = torch.zeros((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
set_attention_mask(attention_mask)
transformer_block_ref.layer_number = 4
out_ref, _ = transformer_block_ref(hidden_states=hidden_states_ref)
out_test, _ = transformer_block_test(hidden_states=hidden_states_test)
assert(torch.allclose(out_ref, out_test))
out_ref.backward(torch.ones_like(out_ref))
out_test.backward(torch.ones_like(out_ref))
assert(torch.allclose(hidden_states_ref.grad, hidden_states_test.grad))