import warnings
import functools
import torch_npu._C
from torch_npu.utils.utils import _print_error_log
# Public names exported by ``from <module> import *``.
__all__ = [
    "mstx",
    "annotate",
]
def _no_exception_func(default_ret=None):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
except Exception as ex:
_print_error_log(f"Call {func.__name__} failed. Exception: {str(ex)}")
return default_ret
return result
return wrapper
return decorator
class mstx:
    """Static helpers that forward mark/range calls to ``torch_npu._C._mstx``.

    Every entry point validates its arguments. On invalid input it emits a
    ``warnings.warn`` and returns a failure sentinel. Each method is wrapped by
    ``_no_exception_func`` so that an exception raised while forwarding is
    logged and replaced by that same sentinel instead of propagating.
    """

    @staticmethod
    @_no_exception_func()
    def mark(message: str, stream=None, domain: str = 'default'):
        """Record a single mark.

        Args:
            message: Non-empty marker text.
            stream: Optional ``torch_npu.npu.streams.Stream``. When given, the
                mark goes through ``_mark`` with the stream's id/device info;
                when ``None`` it goes through ``_mark_on_host``.
            domain: Non-empty domain name.
        """
        # Type check first so non-str objects are never asked for truthiness.
        if not isinstance(message, str) or not message:
            warnings.warn("Invalid message for mstx.mark func. Please input valid message string.")
            return
        # Reject empty domains too, consistent with the range_* functions.
        if not isinstance(domain, str) or not domain:
            warnings.warn("Invalid domain for mstx.mark func. Please input valid domain string.")
            return
        # Compare against None explicitly: a falsy non-stream value (0, "", ...)
        # is reported as invalid instead of silently falling back to the host.
        if stream is None:
            torch_npu._C._mstx._mark_on_host(message, domain)
            return
        if not isinstance(stream, torch_npu.npu.streams.Stream):
            warnings.warn("Invalid stream for mstx.mark func. Please input valid stream.")
            return
        torch_npu._C._mstx._mark(message,
                                 stream.stream_id,
                                 stream.device_index,
                                 stream.device_type,
                                 domain)

    @staticmethod
    @_no_exception_func(default_ret=-1)
    def range_push(message: str, stream=None, domain: str = 'default') -> int:
        """Push a range via ``_range_push`` (stream) or ``_range_push_on_host``.

        Args:
            message: Non-empty range text.
            stream: Optional ``torch_npu.npu.streams.Stream``; ``None`` means host.
            domain: Non-empty domain name.

        Returns:
            The value returned by the backend call, or -1 on invalid input or
            when the backend raises (the decorator fallback matches the
            validation sentinel so callers always get an ``int``).
        """
        if not isinstance(message, str) or not message:
            warnings.warn("Invalid message for mstx.range_push func. Please input valid message string.")
            return -1
        if not isinstance(domain, str) or not domain:
            warnings.warn("Invalid domain for mstx.range_push func. Please input valid domain string.")
            return -1
        if stream is None:
            return torch_npu._C._mstx._range_push_on_host(message, domain)
        if not isinstance(stream, torch_npu.npu.streams.Stream):
            warnings.warn("Invalid stream for mstx.range_push func. Please input valid stream.")
            return -1
        return torch_npu._C._mstx._range_push(message,
                                              stream.stream_id,
                                              stream.device_index,
                                              stream.device_type,
                                              domain)

    @staticmethod
    @_no_exception_func(default_ret=-1)
    def range_pop(domain: str = 'default') -> int:
        """Pop a range via ``_range_pop``.

        Args:
            domain: Non-empty domain name.

        Returns:
            The value returned by the backend call, or -1 on invalid input or
            when the backend raises.
        """
        if not isinstance(domain, str) or not domain:
            warnings.warn("Invalid domain for mstx.range_pop func. Please input valid domain string.")
            return -1
        return torch_npu._C._mstx._range_pop(domain)

    @staticmethod
    @_no_exception_func(default_ret=0)
    def range_start(message: str, stream=None, domain: str = 'default') -> int:
        """Start a range via ``_range_start`` (stream) or ``_range_start_on_host``.

        Args:
            message: Non-empty range text.
            stream: Optional ``torch_npu.npu.streams.Stream``; ``None`` means host.
            domain: Non-empty domain name.

        Returns:
            The id returned by the backend, to be passed to ``range_end``, or
            0 on invalid input or when the backend raises.
        """
        if not isinstance(message, str) or not message:
            warnings.warn("Invalid message for mstx.range_start func. Please input valid message string.")
            return 0
        if not isinstance(domain, str) or not domain:
            warnings.warn("Invalid domain for mstx.range_start func. Please input valid domain string.")
            return 0
        if stream is None:
            return torch_npu._C._mstx._range_start_on_host(message, domain)
        if not isinstance(stream, torch_npu.npu.streams.Stream):
            warnings.warn("Invalid stream for mstx.range_start func. Please input valid stream.")
            return 0
        return torch_npu._C._mstx._range_start(message,
                                               stream.stream_id,
                                               stream.device_index,
                                               stream.device_type,
                                               domain)

    @staticmethod
    @_no_exception_func()
    def range_end(range_id: int, domain: str = 'default'):
        """End the range identified by ``range_id`` via ``_range_end``.

        Args:
            range_id: Value previously returned by ``mstx.range_start``.
            domain: Non-empty domain name.
        """
        if not isinstance(range_id, int):
            warnings.warn("Invalid range_id for mstx.range_end func. Please input return value from mstx.range_start.")
            return
        if not isinstance(domain, str) or not domain:
            warnings.warn("Invalid domain for mstx.range_end func. Please input valid domain string.")
            return
        torch_npu._C._mstx._range_end(range_id, domain)

    @staticmethod
    @_no_exception_func()
    def mstx_range(message: str, stream=None, domain: str = 'default'):
        """Return a decorator that wraps each call in ``range_start``/``range_end``.

        Args:
            message: Range text used for every call of the decorated function.
            stream: Optional stream forwarded to ``range_start``.
            domain: Domain forwarded to ``range_start`` and ``range_end``.
        """
        def wrapper(func):
            @functools.wraps(func)
            def inner(*args, **kwargs):
                range_id = mstx.range_start(message, stream, domain)
                # finally: close the range even when func raises, so no range
                # is left open on the error path.
                try:
                    return func(*args, **kwargs)
                finally:
                    mstx.range_end(range_id, domain)
            return inner
        return wrapper
class annotate:
    """Mark a code region as an mstx range, usable as a context manager or decorator.

    As a context manager, the range covers the ``with`` body. As a decorator,
    it covers every call of the wrapped function; when no message was given,
    the function's ``__name__`` is used as the range message.
    """

    def __init__(self, message: str = '', stream=None, domain: str = 'default'):
        self.message = message
        self.stream = stream
        self.domain = domain
        # Id returned by mstx.range_start for the active ``with`` block, or None.
        self.range_id = None

    def __enter__(self):
        self.range_id = mstx.range_start(self.message, self.stream, self.domain)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.range_id is not None:
            mstx.range_end(self.range_id, self.domain)
            self.range_id = None
        # Implicit None return: exceptions from the body are not suppressed.

    def __call__(self, func):
        """Wrap ``func`` so each call runs inside its own range.

        The message is resolved locally rather than written back to
        ``self.message``. Otherwise one instance used to decorate several
        functions would label all of them with the first function's name.
        Callables without ``__name__`` (e.g. ``functools.partial``) fall back
        to their ``repr``.
        """
        message = self.message or getattr(func, '__name__', repr(func))

        @functools.wraps(func)
        def inner(*args, **kwargs):
            range_id = mstx.range_start(message, self.stream, self.domain)
            try:
                return func(*args, **kwargs)
            finally:
                mstx.range_end(range_id, self.domain)
        return inner
# Expose the annotate helper as ``mstx.annotate`` as well as at module level.
setattr(mstx, "annotate", annotate)