mirror of https://github.com/hpcaitech/ColossalAI
fixed logger
parent
6877121377
commit
598d456d0e
|
@ -1,24 +1,14 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import colossalai
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, List
|
from typing import List, Union
|
||||||
import inspect
|
|
||||||
|
|
||||||
|
import colossalai
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
|
|
||||||
try:
|
|
||||||
from rich.logging import RichHandler
|
|
||||||
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format=_FORMAT,
|
|
||||||
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
|
|
||||||
except ImportError:
|
|
||||||
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
|
|
||||||
logging.basicConfig(level=logging.INFO, format=_FORMAT)
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedLogger:
|
class DistributedLogger:
|
||||||
"""This is a distributed event logger class essentially based on :class:`logging`.
|
"""This is a distributed event logger class essentially based on :class:`logging`.
|
||||||
|
@ -55,8 +45,23 @@ class DistributedLogger:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
|
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
|
||||||
else:
|
else:
|
||||||
|
handler = None
|
||||||
|
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
|
||||||
|
try:
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
except ImportError:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
self._name = name
|
self._name = name
|
||||||
self._logger = logging.getLogger(name)
|
self._logger = logging.getLogger(name)
|
||||||
|
self._logger.setLevel(logging.INFO)
|
||||||
|
if handler is not None:
|
||||||
|
self._logger.addHandler(handler)
|
||||||
|
self._logger.propagate = False
|
||||||
|
|
||||||
DistributedLogger.__instances[name] = self
|
DistributedLogger.__instances[name] = self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Reference in New Issue