diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 9a7309910..729f99619 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -1,8 +1,9 @@ -from typing import List -from .logging import DistributedLogger import logging +from typing import List, Optional -__all__ = ['get_dist_logger', 'DistributedLogger'] +from .logger import DistributedLogger + +__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] def get_dist_logger(name='colossalai'): @@ -18,12 +19,20 @@ def get_dist_logger(name='colossalai'): return DistributedLogger.get_instance(name=name) -def disable_existing_loggers(except_loggers: List[str] = ['colossalai']): - """Set the level of existing loggers to `WARNING`. +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']): + """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". - :param except_loggers: loggers in this `list` will be ignored when disabling, defaults to ['colossalai'] - :type except_loggers: list, optional + Args: + include (Optional[List[str]], optional): Loggers whose name in this list will be disabled. + If set to `None`, `exclude` argument will be used. Defaults to None. + exclude (List[str], optional): Loggers whose name not in this list will be disabled. + This argument will be used only when `include` is None. Defaults to ['colossalai']. """ + if include is None: + filter_func = lambda name: name not in exclude + else: + filter_func = lambda name: name in include + for log_name in logging.Logger.manager.loggerDict.keys(): - if log_name not in except_loggers: + if filter_func(log_name): logging.getLogger(log_name).setLevel(logging.WARNING) diff --git a/colossalai/logging/logging.py b/colossalai/logging/logger.py similarity index 100% rename from colossalai/logging/logging.py rename to colossalai/logging/logger.py