import warnings
import torch_npu

# Nothing is exported via ``from ... import *``. Per this module's deprecation
# message, callers are expected to use ``torch_npu.utils.FlopsCounter()``
# instead (presumably re-exported from the private class elsewhere — confirm).
__all__ = []


class _FlopsCounter:
    """Thin Python wrapper over the native NPU FLOP-count context.

    Counting can be driven explicitly via ``start``/``stop``/``pause``/
    ``resume``, or by using an instance as a context manager::

        with _FlopsCounter() as counter:
            ...  # code whose FLOPs should be counted
            recorded, traversed = counter.get_flops()

    Leaving the ``with`` block calls ``stop()``, which also calls ``reset()``
    on the native context, so read ``get_flops()`` before the block exits.
    """

    def __init__(self):
        # NOTE(review): ``GetInstance`` suggests a process-wide singleton, so
        # separate counter objects likely drive the same native state — confirm.
        self.flop_count_instance = torch_npu._C._flops_count._FlopCountContext.GetInstance()

    def __enter__(self):
        """Start counting and return this counter for ``with ... as``."""
        self.start()
        # Returning self lets ``with _FlopsCounter() as c:`` bind the counter.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop and reset counting on exit.

        Returns None (falsy), so any exception raised in the ``with`` body
        propagates to the caller.
        """
        self.stop()

    def start(self):
        """Enable counting on the native context."""
        self.flop_count_instance.enable()

    def stop(self):
        """Disable counting, then reset the native context's counters."""
        self.flop_count_instance.disable()
        self.flop_count_instance.reset()

    def pause(self):
        """Pause counting on the native context (see ``resume``)."""
        self.flop_count_instance.pause()

    def resume(self):
        """Resume counting on the native context after ``pause``."""
        self.flop_count_instance.resume()

    def get_flops(self):
        """Return ``[recordedCount, traversedCount]`` from the native context.

        NOTE(review): the exact meaning of the two counts is defined by the
        native implementation (presumably counts taken while counting is
        active vs. all traversed ops, including paused periods) — confirm.
        """
        recorded_count = self.flop_count_instance.recordedCount
        traversed_count = self.flop_count_instance.traversedCount
        return [recorded_count, traversed_count]


class FlopsCounter(_FlopsCounter):
    """Deprecated public FLOP counter kept for backward compatibility.

    Behaves exactly like its private base class, but emits a
    ``FutureWarning`` on construction directing users to
    ``torch_npu.utils.FlopsCounter()``.
    """

    def __init__(self):
        super().__init__()
        # stacklevel=2 attributes the warning to the code that constructed the
        # counter rather than to this __init__, so users can find the call site.
        warnings.warn("torch_npu.utils.flops_count.FlopsCounter() will be deprecated. "
                      "If necessary, please use torch_npu.utils.FlopsCounter().", FutureWarning,
                      stacklevel=2)