"""
Global logging configuration module with custom distributed logger (rank0 support).
"""
import logging
import sys
import os
import threading
from functools import lru_cache
from typing import Optional
_thread_lock = threading.Lock()
_default_handler = None
_initialized = False
class _Logger(logging.Logger):
"""Custom logger class with rank0-specific logging methods and plain message output.
This logger extends the standard logging.Logger to support:
1. Rank0-specific logging (only log on the main process in distributed scenarios)
2. One-time warning logging (avoid duplicate warning messages)
3. Plain message logging (output without level/time prefix)
"""
def info_rank0(self, msg: str, *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
self.info(msg, *args, **kwargs)
def warn_rank0(self, msg: str, *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
self.warning(msg, *args, **kwargs)
@lru_cache(None)
def warn_once(self, msg: str, *args, **kwargs) -> None:
"""Log warning level message only once (lifetime) and only on the process with LOCAL_RANK=0.
"""
if int(os.getenv("LOCAL_RANK", "0")) == 0:
kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
self.warning_rank0(msg, *args, **kwargs)
def debug_rank0(self, msg: str, *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
self.debug(msg, *args, **kwargs)
def log_plain(self, msg: str, *args, **kwargs) -> None:
"""Core method for plain message logging without any prefix.
Writes directly to stdout to avoid the standard logging format.
Args:
msg (str): Raw message to be logged (supports string formatting with args).
"""
formatted_msg = msg % args if args else msg
sys.stdout.write(f"{formatted_msg}\n")
sys.stdout.flush()
def info_plain(self, msg: str, *args, **kwargs) -> None:
self.log_plain(msg, *args, **kwargs)
def info_plain_rank0(self, msg: str, *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.log_plain(msg, *args, **kwargs)
logging.setLoggerClass(_Logger)
def _get_library_name() -> str:
"""Get the root name of the library (first part of the module name).
Returns:
str: Root name of the current library module.
"""
return __name__.split(".")[0]
def _get_library_root_logger() -> _Logger:
"""Get the root logger instance of the library (custom _Logger type).
Returns:
_Logger: Custom logger instance for the library root.
"""
return logging.getLogger(_get_library_name())
def setup_global_logging(level: str = "INFO") -> None:
"""Configure the global logging system with rank0 support and dynamic format.
This function initializes the logging system with thread-safe configuration,
dynamic log format (DEBUG level includes path/line number), and avoids duplicate handlers.
It should be called only once at the project startup (main entry file).
Args:
level (str, optional): Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL). Defaults to "INFO".
"""
global _default_handler, _initialized
if _initialized:
return
valid_levels = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
log_level = valid_levels.get(level.strip().upper(), logging.INFO)
if log_level == logging.DEBUG:
log_format = "%(levelname)s [%(asctime)s %(pathname)s:%(lineno)s] >> %(message)s"
else:
log_format = "%(levelname)s [%(asctime)s] >> %(message)s"
with _thread_lock:
logging.root.handlers.clear()
logging.basicConfig(
level=log_level,
format=log_format,
datefmt="%Y-%m-%d %H:%M:%S",
stream=sys.stdout,
force=True
)
library_root_logger = _get_library_root_logger()
library_root_logger.handlers.clear()
if not _default_handler:
formatter = logging.Formatter(
fmt=log_format,
datefmt="%Y-%m-%d %H:%M:%S"
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(log_level)
library_root_logger.propagate = False
_initialized = True
library_root_logger.info_rank0(
f"Global logging initialized | Level: {level.upper()} | Local Rank: {os.getenv('LOCAL_RANK', '0')}"
)
def get_logger(name: Optional[str] = None, level: Optional[str] = "INFO") -> _Logger:
"""Get a custom _Logger instance with rank0 support.
Ensures the global logging configuration is applied before returning the logger instance.
Uses INFO as the default log level if not specified, or overrides with the given level.
Args:
name (Optional[str], optional): Name of the logger, defaults to library root name if None.
level (Optional[str], optional): Log level for the logger (DEBUG/INFO/WARNING/ERROR/CRITICAL).
Defaults to "INFO".
Returns:
_Logger: Custom logger instance with rank0 logging methods.
"""
if name is None:
name = _get_library_name()
setup_global_logging(level=level)
return logging.getLogger(name)