import time
from functools import wraps

import torch
from megatron.training import get_args

from mindspeed.auto_settings.module.black.auto_patch import AutoPatcher


def compute_language_model_loss_wrapper(fn):
    """Decorate ``fn`` so each call records its run time and memory usage.

    Profiling is active only when ``get_args().prof_file`` is truthy;
    otherwise the returned wrapper simply forwards the call to ``fn``.

    When active, every call writes three entries into ``context['loss']``
    of the ``AutoPatcher`` built from ``prof_file``:

    * ``'time'``: elapsed time of ``fn`` in milliseconds, measured between
      two ``torch.cuda.synchronize()`` calls.
    * ``'memory'``: change of the first value returned by
      ``get_memory_status()`` across the call, divided by ``unit_gb``.
    * ``'max_memory'``: second value returned by ``get_memory_status()``
      after the call minus the first value returned before the call,
      divided by ``unit_gb``.

    Args:
        fn: Callable to wrap (presumably the language-model loss
            computation, judging by the wrapper's name).

    Returns:
        A wrapper that accepts the same arguments as ``fn`` and returns
        whatever ``fn`` returns.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        prof_file = get_args().prof_file
        if not prof_file:
            # Profiling disabled: behave exactly like the wrapped function.
            return fn(*args, **kwargs)

        auto_profiler = AutoPatcher(prof_file)
        if 'loss' not in auto_profiler.context:
            auto_profiler.context['loss'] = {}

        # Wait for previously queued device work so it is not attributed
        # to fn's time.
        torch.cuda.synchronize()
        used_mem, _ = auto_profiler.get_memory_status()
        # perf_counter() is monotonic and high resolution; time.time() can
        # jump when the system clock is adjusted, corrupting the interval.
        start_time = time.perf_counter()
        loss = fn(*args, **kwargs)
        # Wait for the work launched by fn before stopping the clock.
        torch.cuda.synchronize()
        loss_time = (time.perf_counter() - start_time) * 1000  # s -> ms
        cur_used_mem, cur_max_mem = auto_profiler.get_memory_status()

        loss_stats = auto_profiler.context['loss']
        loss_stats['time'] = loss_time
        loss_stats['memory'] = (cur_used_mem - used_mem) / auto_profiler.unit_gb
        # NOTE(review): whether cur_max_mem only covers this call depends on
        # how get_memory_status() tracks/resets its peak value -- confirm.
        loss_stats['max_memory'] = (cur_max_mem - used_mem) / auto_profiler.unit_gb
        return loss

    return wrapper