[doc] improved docstring in the logging module (#861)

pull/867/head
Frank Lee 2022-04-25 13:42:00 +08:00 committed by GitHub
parent 8004c8e938
commit b862d89d00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 20 deletions

View File

@ -6,22 +6,20 @@ from .logger import DistributedLogger
__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers']
def get_dist_logger(name='colossalai'):
def get_dist_logger(name: str = 'colossalai') -> DistributedLogger:
"""Get logger instance based on name. The DistributedLogger will create singleton instances,
which means that only one logger instance is created per name.
Args:
:param name: name of the logger, name must be unique
:type name: str
:return: a distributed logger instance
:rtype: :class:`colossalai.logging.DistributedLogger`
name (str): name of the logger, name must be unique
Returns:
:class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance.
"""
return DistributedLogger.get_instance(name=name)
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']):
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None:
"""Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai".
Args:

View File

@ -4,7 +4,7 @@
import colossalai
import logging
from pathlib import Path
from typing import Union
from typing import Union, List
import inspect
from colossalai.context.parallel_mode import ParallelMode
@ -40,6 +40,7 @@ class DistributedLogger:
Args:
name (str): The name of the logger.
Returns:
DistributedLogger: A DistributedLogger object
"""
@ -75,7 +76,7 @@ class DistributedLogger:
def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str):
def set_level(self, level: str) -> None:
"""Set the logging level
Args:
@ -84,7 +85,7 @@ class DistributedLogger:
self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level))
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
"""Save the logs to file
Args:
@ -122,7 +123,11 @@ class DistributedLogger:
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def _log(self,
level,
message: str,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks: List[int] = None) -> None:
if ranks is None:
getattr(self._logger, level)(message)
else:
@ -130,53 +135,53 @@ class DistributedLogger:
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an info message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a warning message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a debug message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an error message.
Args:
message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks.
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks)