from mindspeed.features_manager.feature import MindSpeedFeature


class RecomputeActivationFeature(MindSpeedFeature):
    """Feature toggle for recomputing the MLP activation function.

    Registers the ``--recompute-activation-function`` and
    ``--recompute-activation-function-num-layers`` command-line options,
    validates them, and (when enabled) patches Megatron's
    ``TransformerLayer.__init__`` and ``MLP.forward`` with MindSpeed's
    activation-recompute implementations.
    """

    def __init__(self):
        super().__init__('recompute-activation-function')

    def register_args(self, parser):
        """Add this feature's options to *parser* in their own argument group."""
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--recompute-activation-function', action='store_true',
                           help='Recompute the activation function in MLP layers.')
        group.add_argument('--recompute-activation-function-num-layers', type=int, default=None,
                           help='Can be used together with "--recompute-method block" '
                                'and "--recompute-num-layers".')

    def validate_args(self, args):
        """Check ``--recompute-activation-function-num-layers`` if it was given.

        Raises:
            TypeError: the value is not an integer.
            AssertionError: the value is negative (kept as AssertionError for
                compatibility with existing callers).
            ValueError: the value exceeds ``--num-layers``.
        """
        num_recompute_layers = args.recompute_activation_function_num_layers
        if num_recompute_layers is None:
            return

        # argparse enforces ``type=int``, but args may also be set programmatically.
        if not isinstance(num_recompute_layers, int):
            raise TypeError('--recompute-activation-function-num-layers must be an integer.')
        if num_recompute_layers < 0:
            raise AssertionError('--recompute-activation-function-num-layers cannot be less than 0.')

        # Only compare against --num-layers when it is actually set; comparing
        # an int with None would raise an unrelated TypeError.
        num_layers = getattr(args, 'num_layers', None)
        if num_layers is not None and num_recompute_layers > num_layers:
            raise ValueError(f'--recompute-activation-function-num-layers ({num_recompute_layers}) '
                             f'cannot be greater than --num-layers ({num_layers}).')

    def register_patches(self, patch_manager, args):
        """Register the activation-recompute patches when the feature is enabled."""
        # argparse stores "--recompute-activation-function" under the dest
        # "recompute_activation_function"; normalise the feature name so the
        # lookup works whether or not the base class already replaced '-'.
        enabled = getattr(args, self.feature_name.replace('-', '_'), None)
        if not enabled:
            return

        # Import lazily so the patch modules are only loaded when the feature is on.
        from mindspeed.core.memory.recompute.activation.adaptor import mindspeed_activation_recompute_forward
        from mindspeed.core.transformer.transformer import parallel_transformer_layer_init_wrapper

        patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer.__init__',
                                     parallel_transformer_layer_init_wrapper)
        patch_manager.register_patch('megatron.core.transformer.mlp.MLP.forward', mindspeed_activation_recompute_forward)