import functools
import os

import torch._logging._internal

from torch_npu import _C
def _set_logs():
"""
Propagate the results torch._logging.set_logs to the C++ layer.
.. note:: The ``TORCH_LOGS`` or ``TORCH_NPU_LOGS`` environment variable has complete precedence
over this function, so if it was set, this function does nothing.
"""
if os.environ.get('TORCH_LOGS', None) is not None or os.environ.get('TORCH_NPU_LOGS', None) is not None:
return
_C._logging._LogContext.GetInstance().setLogs(torch._logging._internal.log_state.log_qname_to_level)
def _trigger_set_logs_decorator(func):
    """
    Wrap ``func`` so that ``_set_logs()`` runs after every call that returns.

    Args:
        func: The callable to wrap (``_logging_patch`` passes
            ``torch._logging.set_logs``).

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func``, then calls ``_set_logs()``, and finally returns ``func``'s
        result. The wrapper carries ``func``'s metadata (``__name__``,
        ``__doc__``, ``__wrapped__``), so introspection of the patched function
        still sees the original.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # Runs only after ``func`` has returned; if ``func`` raises, the
        # exception propagates and ``_set_logs()`` is not called.
        _set_logs()
        return result

    return wrapper
def _logging_patch():
    """Rebind ``torch._logging.set_logs`` to a wrapped version that also calls ``_set_logs()``."""
    patched_set_logs = _trigger_set_logs_decorator(torch._logging.set_logs)
    torch._logging.set_logs = patched_set_logs
def _add_logging_module():
torch._logging._internal.register_log("memory", "torch_npu.memory")
torch._logging._internal.register_log("dispatch", "torch_npu.dispatch")
torch._logging._internal.register_log("dispatch_time", "torch_npu.dispatch.time")
torch._logging._internal.register_log("silent", "torch_npu.silent_check")
torch._logging._internal.register_log("recovery", "torch_npu.recovery")
torch._logging._internal.register_log("op_plugin", "torch_npu.op_plugin")
torch._logging._internal.register_log("shmem", "torch_npu.symmetric_memory")
torch._logging._internal.register_log("env", "torch_npu.env")
torch._logging._internal.register_log("acl", "torch_npu.acl")
torch._logging._internal.register_log("aclgraph", "torch_npu.aclgraph")
torch._logging._internal.register_log("npugraph", "torch_npu.npugraph")
torch._logging._internal.register_log("cudagraphs", "torch_npu.npugraph")
def _update_log_state_from_env():
log_setting = os.environ.get("TORCH_NPU_LOGS", None)
if log_setting is not None:
torch._logging._internal.LOG_ENV_VAR = "TORCH_NPU_LOGS"
torch._logging._internal._init_logs()
_C._logging._LogContext.GetInstance().setLogs(torch._logging._internal.log_state.log_qname_to_level)
elif os.environ.get("TORCH_LOGS", None) is not None:
_C._logging._LogContext.GetInstance().setLogs(torch._logging._internal.log_state.log_qname_to_level)