from functools import wraps
from torch_npu.utils.flops_count import FlopsCounter
FLOPS_COUNTER = None
RECORDED_COUNT = 0
TRAVERSED_COUNT = 0
def get_count():
"""Get recorded and traversed FLOPS counts."""
if RECORDED_COUNT == 0 and TRAVERSED_COUNT == 0:
flops_counter = get_flops_counter()
count = flops_counter.get_flops()
set_count(count)
return RECORDED_COUNT, TRAVERSED_COUNT
def set_count(count):
"""Set the FLOPS counts."""
global RECORDED_COUNT
global TRAVERSED_COUNT
RECORDED_COUNT = count[0]
TRAVERSED_COUNT = count[1]
def get_flops_counter():
"""Get or initialize the FLOPS counter."""
global FLOPS_COUNTER
if FLOPS_COUNTER is None:
FLOPS_COUNTER = FlopsCounter()
return FLOPS_COUNTER
def checkpoint_function_backward_wrapper(fn):
@wraps(fn)
def wrapper(ctx, *args):
flops_counter = get_flops_counter()
flops_counter.pause()
result = fn(ctx, *args)
flops_counter.resume()
return result
return wrapper
def train_step_wrapper(train_step):
@wraps(train_step)
def wrapper(*args, **kwargs):
flop_count = get_flops_counter()
flop_count.start()
ret = train_step(*args, **kwargs)
counts = flop_count.get_flops()
set_count(counts)
flop_count.stop()
return ret
return wrapper