"""Configuration of logging."""
import os
from logging.config import dictConfig
from logging import Logger
from typing import Callable
import torch
# Log level name taken from the environment. It is normalized (whitespace
# stripped, upper-cased) because ``dictConfig`` rejects level names such as
# "debug" or " INFO" with ``ValueError``; an empty value falls back to INFO.
LOG_LEVEL = os.getenv("MIND_SPEED_LOG_LEVEL", "INFO").strip().upper() or "INFO"

# Rank identifiers read from the environment. They are strings when the
# variables are set and the integer 0 otherwise; in this module they are only
# interpolated into the log format below.
RANK = os.getenv("RANK", 0)
LOCAL_RANK = os.getenv("LOCAL_RANK", 0)

# Dictionary-schema logging configuration: a single stdout console handler
# attached to the root logger, with every record prefixed by the rank info.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "standard": {
            "format": f"[Rank {RANK} | Local Rank {LOCAL_RANK}] %(asctime)s "
            "%(levelname)s [%(name)s:%(lineno)d] => %(message)s",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": LOG_LEVEL,
            "formatter": "standard",
            "stream": "ext://sys.stdout",
        },
    },
    "root": {
        "handlers": ["console"],
        "level": LOG_LEVEL,
    },
}

# Messages already emitted through ``log_warning_once``.
_warned_messages = set()
def set_log_config():
    """Put the module's logging configuration into effect.

    Passes ``LOGGING_CONFIG`` to ``logging.config.dictConfig``.
    """
    active_config = LOGGING_CONFIG
    dictConfig(active_config)
def log_rank_0(log: Callable, message: str):
    """Log ``message``, restricted to rank 0 when distributed is initialized.

    If ``torch.distributed`` is unavailable, or no process group has been
    initialized, the message is logged unconditionally.

    Args:
        log (Callable): A callable that emits a single message, such as a
            bound logger method:

            ```python
            LOG = getLogger(__name__)
            log_rank_0(LOG.info, "message")
            ```
        message (str): The log message.
    """
    distributed = torch.distributed
    # Check availability first: on torch builds without distributed support
    # the package does not expose is_initialized()/get_rank().
    if distributed.is_available() and distributed.is_initialized():
        if distributed.get_rank() != 0:
            # Non-zero ranks stay silent.
            return
    log(message)
def log_warning_once(logger, message):
    """Emit ``message`` via ``logger.warning`` only the first time it is seen.

    Messages are remembered in the module-level ``_warned_messages`` set, so
    a later call with an identical message does nothing, whichever logger is
    passed.

    Args:
        logger: Object exposing a ``warning(message)`` method.
        message: The warning text; it is also the de-duplication key.
    """
    already_warned = message in _warned_messages
    if already_warned:
        return
    logger.warning(message)
    # Record only after the warning was emitted.
    _warned_messages.add(message)