from functools import wraps

import torch
from megatron.training.global_vars import get_args

from .hccl_operator import MOEOrMLPStartOp, MOEOrMLPEndOp


def moelayer_forward_decorator(fn):

    @wraps
    def wrapper(*args, **kwargs):
        prof_file = get_args().prof_file
        if prof_file:
            args[1] = MOEOrMLPStartOp.apply(args[1])
            activation_func_1 = torch.nn.Softplus()
            args[1] = activation_func_1(args[1])

            output, mlp_bias = fn(*args, **kwargs)

            activation_func_2 = torch.nn.Softshrink()
            output = activation_func_2(output)
            output = MOEOrMLPEndOp.apply(output)
            return output, mlp_bias
        
        return fn(*args, **kwargs)
    
    return wrapper