mirror of https://github.com/hpcaitech/ColossalAI
[logging] polish logger format (#543)
parent
1f90a3b129
commit
7d81b5b46e
|
@ -5,15 +5,18 @@ import colossalai
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
import inspect
|
||||||
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
|
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
||||||
logging.basicConfig(level=logging.INFO, format=_FORMAT, handlers=[RichHandler()])
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format=_FORMAT,
|
||||||
|
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
|
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
||||||
logging.basicConfig(level=logging.INFO, format=_FORMAT)
|
logging.basicConfig(level=logging.INFO, format=_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,6 +53,19 @@ class DistributedLogger:
|
||||||
self._logger = logging.getLogger(name)
|
self._logger = logging.getLogger(name)
|
||||||
DistributedLogger.__instances[name] = self
|
DistributedLogger.__instances[name] = self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_call_info():
|
||||||
|
stack = inspect.stack()
|
||||||
|
|
||||||
|
# stack[1] gives previous function ('info' in our case)
|
||||||
|
# stack[2] gives before previous function and so on
|
||||||
|
|
||||||
|
fn = stack[2][1]
|
||||||
|
ln = stack[2][2]
|
||||||
|
func = stack[2][3]
|
||||||
|
|
||||||
|
return fn, ln, func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_valid_logging_level(level: str):
|
def _check_valid_logging_level(level: str):
|
||||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||||
|
@ -122,6 +138,8 @@ class DistributedLogger:
|
||||||
:param ranks: List of parallel ranks
|
:param ranks: List of parallel ranks
|
||||||
:type ranks: list
|
:type ranks: list
|
||||||
"""
|
"""
|
||||||
|
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||||
|
self._log('info', message_prefix, parallel_mode, ranks)
|
||||||
self._log('info', message, parallel_mode, ranks)
|
self._log('info', message, parallel_mode, ranks)
|
||||||
|
|
||||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||||
|
@ -134,6 +152,8 @@ class DistributedLogger:
|
||||||
:param ranks: List of parallel ranks
|
:param ranks: List of parallel ranks
|
||||||
:type ranks: list
|
:type ranks: list
|
||||||
"""
|
"""
|
||||||
|
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||||
|
self._log('warning', message_prefix, parallel_mode, ranks)
|
||||||
self._log('warning', message, parallel_mode, ranks)
|
self._log('warning', message, parallel_mode, ranks)
|
||||||
|
|
||||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||||
|
@ -146,6 +166,8 @@ class DistributedLogger:
|
||||||
:param ranks: List of parallel ranks
|
:param ranks: List of parallel ranks
|
||||||
:type ranks: list
|
:type ranks: list
|
||||||
"""
|
"""
|
||||||
|
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||||
|
self._log('debug', message_prefix, parallel_mode, ranks)
|
||||||
self._log('debug', message, parallel_mode, ranks)
|
self._log('debug', message, parallel_mode, ranks)
|
||||||
|
|
||||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||||
|
@ -158,4 +180,6 @@ class DistributedLogger:
|
||||||
:param ranks: List of parallel ranks
|
:param ranks: List of parallel ranks
|
||||||
:type ranks: list
|
:type ranks: list
|
||||||
"""
|
"""
|
||||||
|
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||||
|
self._log('error', message_prefix, parallel_mode, ranks)
|
||||||
self._log('error', message, parallel_mode, ranks)
|
self._log('error', message, parallel_mode, ranks)
|
||||||
|
|
Loading…
Reference in New Issue