mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved docstring in the logging module (#861)
parent
8004c8e938
commit
b862d89d00
|
@ -6,22 +6,20 @@ from .logger import DistributedLogger
|
|||
__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers']
|
||||
|
||||
|
||||
def get_dist_logger(name='colossalai'):
|
||||
def get_dist_logger(name: str = 'colossalai') -> DistributedLogger:
|
||||
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
|
||||
which means that only one logger instance is created per name.
|
||||
|
||||
Args:
|
||||
|
||||
:param name: name of the logger, name must be unique
|
||||
:type name: str
|
||||
|
||||
:return: a distributed logger instance
|
||||
:rtype: :class:`colossalai.logging.DistributedLogger`
|
||||
name (str): name of the logger, name must be unique
|
||||
|
||||
Returns:
|
||||
:class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance.
|
||||
"""
|
||||
return DistributedLogger.get_instance(name=name)
|
||||
|
||||
|
||||
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']):
|
||||
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None:
|
||||
"""Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai".
|
||||
|
||||
Args:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import colossalai
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
import inspect
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
@ -40,6 +40,7 @@ class DistributedLogger:
|
|||
|
||||
Args:
|
||||
name (str): The name of the logger.
|
||||
|
||||
Returns:
|
||||
DistributedLogger: A DistributedLogger object
|
||||
"""
|
||||
|
@ -75,7 +76,7 @@ class DistributedLogger:
|
|||
def _check_valid_logging_level(level: str):
|
||||
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
|
||||
|
||||
def set_level(self, level: str):
|
||||
def set_level(self, level: str) -> None:
|
||||
"""Set the logging level
|
||||
|
||||
Args:
|
||||
|
@ -84,7 +85,7 @@ class DistributedLogger:
|
|||
self._check_valid_logging_level(level)
|
||||
self._logger.setLevel(getattr(logging, level))
|
||||
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
|
||||
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
|
||||
"""Save the logs to file
|
||||
|
||||
Args:
|
||||
|
@ -122,7 +123,11 @@ class DistributedLogger:
|
|||
file_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(file_handler)
|
||||
|
||||
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def _log(self,
|
||||
level,
|
||||
message: str,
|
||||
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
|
||||
ranks: List[int] = None) -> None:
|
||||
if ranks is None:
|
||||
getattr(self._logger, level)(message)
|
||||
else:
|
||||
|
@ -130,53 +135,53 @@ class DistributedLogger:
|
|||
if local_rank in ranks:
|
||||
getattr(self._logger, level)(message)
|
||||
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log an info message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('info', message_prefix, parallel_mode, ranks)
|
||||
self._log('info', message, parallel_mode, ranks)
|
||||
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log a warning message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('warning', message_prefix, parallel_mode, ranks)
|
||||
self._log('warning', message, parallel_mode, ranks)
|
||||
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log a debug message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('debug', message_prefix, parallel_mode, ranks)
|
||||
self._log('debug', message, parallel_mode, ranks)
|
||||
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
|
||||
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
|
||||
"""Log an error message.
|
||||
|
||||
Args:
|
||||
message (str): The message to be logged.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
|
||||
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
|
||||
ranks (List): List of parallel ranks.
|
||||
ranks (List[int]): List of parallel ranks.
|
||||
"""
|
||||
message_prefix = "{}:{} {}".format(*self.__get_call_info())
|
||||
self._log('error', message_prefix, parallel_mode, ranks)
|
||||
|
|
Loading…
Reference in New Issue