mirror of https://github.com/hpcaitech/ColossalAI
[logging] polish logger format (#543)
parent
1f90a3b129
commit
7d81b5b46e
|
@ -5,15 +5,18 @@ import colossalai
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import inspect
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
try:
|
||||
from rich.logging import RichHandler
|
||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=_FORMAT, handlers=[RichHandler()])
|
||||
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format=_FORMAT,
|
||||
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
|
||||
except ImportError:
|
||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
|
||||
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=_FORMAT)
|
||||
|
||||
|
||||
|
@ -50,6 +53,19 @@ class DistributedLogger:
|
|||
self._logger = logging.getLogger(name)
|
||||
DistributedLogger.__instances[name] = self
|
||||
|
||||
@staticmethod
|
||||
def __get_call_info():
|
||||
stack = inspect.stack()
|
||||
|
||||
# stack[1] gives previous function ('info' in our case)
|
||||
# stack[2] gives before previous function and so on
|
||||
|
||||
fn = stack[2][1]
|
||||
ln = stack[2][2]
|
||||
func = stack[2][3]
|
||||
|
||||
return fn, ln, func
|
||||
|
||||
@staticmethod
|
||||
def _check_valid_logging_level(level: str):
|
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||
|
@ -122,6 +138,8 @@ class DistributedLogger:
|
|||
:param ranks: List of parallel ranks
|
||||
:type ranks: list
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('info', message_prefix, parallel_mode, ranks)
|
||||
self._log('info', message, parallel_mode, ranks)
|
||||
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
|
@ -134,6 +152,8 @@ class DistributedLogger:
|
|||
:param ranks: List of parallel ranks
|
||||
:type ranks: list
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('warning', message_prefix, parallel_mode, ranks)
|
||||
self._log('warning', message, parallel_mode, ranks)
|
||||
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
|
@ -146,6 +166,8 @@ class DistributedLogger:
|
|||
:param ranks: List of parallel ranks
|
||||
:type ranks: list
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('debug', message_prefix, parallel_mode, ranks)
|
||||
self._log('debug', message, parallel_mode, ranks)
|
||||
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
|
@ -158,4 +180,6 @@ class DistributedLogger:
|
|||
:param ranks: List of parallel ranks
|
||||
:type ranks: list
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('error', message_prefix, parallel_mode, ranks)
|
||||
self._log('error', message, parallel_mode, ranks)
|
||||
|
|
Loading…
Reference in New Issue