[logging] polish logger format (#543)

pull/539/head^2
Jiarui Fang 2022-03-29 10:37:11 +08:00 committed by GitHub
parent 1f90a3b129
commit 7d81b5b46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 27 additions and 3 deletions

View File

@ -5,15 +5,18 @@ import colossalai
import logging
from pathlib import Path
from typing import Union
import inspect
from colossalai.context.parallel_mode import ParallelMode
try:
from rich.logging import RichHandler
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT, handlers=[RichHandler()])
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO,
format=_FORMAT,
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
except ImportError:
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
@ -50,6 +53,19 @@ class DistributedLogger:
self._logger = logging.getLogger(name)
DistributedLogger.__instances[name] = self
@staticmethod
def __get_call_info():
stack = inspect.stack()
# stack[1] gives previous function ('info' in our case)
# stack[2] gives before previous function and so on
fn = stack[2][1]
ln = stack[2][2]
func = stack[2][3]
return fn, ln, func
@staticmethod
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
@ -122,6 +138,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@ -134,6 +152,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@ -146,6 +166,8 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
@ -158,4 +180,6 @@ class DistributedLogger:
:param ranks: List of parallel ranks
:type ranks: list
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks)
self._log('error', message, parallel_mode, ranks)