fixed logger

pull/1958/head
zbian 2022-11-15 15:07:50 +08:00 committed by アマデウス
parent 6877121377
commit 598d456d0e
1 changed files with 19 additions and 14 deletions

View File

@ -1,24 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import colossalai import inspect
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union, List from typing import List, Union
import inspect
import colossalai
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
try:
from rich.logging import RichHandler
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO,
format=_FORMAT,
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
except ImportError:
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
class DistributedLogger: class DistributedLogger:
"""This is a distributed event logger class essentially based on :class:`logging`. """This is a distributed event logger class essentially based on :class:`logging`.
@ -40,7 +30,7 @@ class DistributedLogger:
Args: Args:
name (str): The name of the logger. name (str): The name of the logger.
Returns: Returns:
DistributedLogger: A DistributedLogger object DistributedLogger: A DistributedLogger object
""" """
@ -55,8 +45,23 @@ class DistributedLogger:
raise Exception( raise Exception(
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
else: else:
handler = None
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
try:
from rich.logging import RichHandler
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
handler.setFormatter(formatter)
except ImportError:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self._name = name self._name = name
self._logger = logging.getLogger(name) self._logger = logging.getLogger(name)
self._logger.setLevel(logging.INFO)
if handler is not None:
self._logger.addHandler(handler)
self._logger.propagate = False
DistributedLogger.__instances[name] = self DistributedLogger.__instances[name] = self
@staticmethod @staticmethod