import os
from pathlib import Path
from functools import wraps
from logging import getLogger
import torch
import torch_npu
from msprobe.pytorch import PrecisionDebugger
current_dir = Path(__file__).parent.resolve()
MSTT_DEBUGGER = PrecisionDebugger(config_path=str(current_dir) + '/config.json')
LOG = getLogger(__name__)
def dump_start_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
MSTT_DEBUGGER.start(model=args[2])
return fn(*args, **kwargs)
return wrapper
def dump_end_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
fn(*args, **kwargs)
MSTT_DEBUGGER.stop()
MSTT_DEBUGGER.step()
return wrapper