diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 208bdb7c5..df741de47 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
   rev: v0.32.0
   hooks:
   - id: yapf
-    args: ['--style=google', '--parallel', '--in-place']
+    args: ['--style=.style.yapf', '--parallel', '--in-place']
 - repo: https://github.com/pycqa/flake8
   rev: '4.0.1'
   hooks:
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 000000000..05be0dc6a
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,5 @@
+[style]
+based_on_style = google
+spaces_before_comment = 4
+split_before_logical_operator = true
+column_limit = 120
diff --git a/colossalai/logging/logging.py b/colossalai/logging/logging.py
index c1760f7ed..089308188 100644
--- a/colossalai/logging/logging.py
+++ b/colossalai/logging/logging.py
@@ -8,7 +8,6 @@ from typing import Union
 
 from colossalai.context.parallel_mode import ParallelMode
 
-
 _FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=_FORMAT)
 
@@ -39,7 +38,8 @@ class DistributedLogger:
 
     def __init__(self, name):
         if name in DistributedLogger.__instances:
-            raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
+            raise Exception(
+                'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
         else:
             self._name = name
             self._logger = logging.getLogger(name)
@@ -58,11 +58,7 @@ class DistributedLogger:
         self._check_valid_logging_level(level)
         self._logger.setLevel(getattr(logging, level))
 
-    def log_to_file(self,
-                    path: Union[str, Path],
-                    mode: str = 'a',
-                    level: str = 'INFO',
-                    suffix: str = None):
+    def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
         """Save the logs to file
 
         :param path: The file to save the log
@@ -77,9 +73,13 @@ class DistributedLogger:
         assert isinstance(path, (str, Path)), \
             f'expected argument path to be type str or Path, but got {type(path)}'
         self._check_valid_logging_level(level)
+
         if isinstance(path, str):
             path = Path(path)
 
+        # create log directory
+        path.mkdir(parents=True, exist_ok=True)
+
         # set the default file name if path is a directory
         if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
             rank = 0
diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index d818ad0c0..bb39c07d2 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import re
 from typing import Tuple
+from pathlib import Path
 
 import torch
 
@@ -10,10 +11,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 
 __all__ = [
-    'get_checkpoint_path',
-    'get_latest_checkpoint_path',
-    'get_latest_checkpoint_pattern',
-    'save_checkpoint',
+    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
     'load_checkpoint'
 ]
 
@@ -70,9 +68,9 @@ def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
 
 def _ensure_directory_exists(filename: str):
     # ensure the directory exists
-    dir = os.path.dirname(filename)
-    if not os.path.exists(dir):
-        os.makedirs(dir)
+    dirpath = os.path.dirname(filename)
+    if not os.path.exists(dirpath):
+        Path(dirpath).mkdir(parents=True, exist_ok=True)
 
 
 def get_latest_checkpoint_pattern(suffix: str = ''):
@@ -84,7 +82,8 @@ def get_latest_checkpoint_pattern(suffix: str = ''):
     :rtype: regular expression
     """
     ranks_name = _get_ranks_name()
-    ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
+    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
+    ckpt_pattern = re.compile(pattern)
     return ckpt_pattern
 
 
@@ -127,7 +126,8 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model,
+    optimizer, lr_scheduler and etc. into a checkpoint dictionary.
 
     This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
 
@@ -150,12 +150,7 @@ def save_checkpoint(checkpoint_path: str,
         model_sd = model.state_dict()
 
     # ckpt container
-    checkpoint = {
-        'epoch': epoch,
-        'model': model_sd,
-        'optimizer': optimizer.state_dict(),
-        **kwargs
-    }
+    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
 
     if lr_scheduler is not None:
         checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
@@ -171,9 +166,11 @@ def load_checkpoint(checkpoint_path: str,
                     strict: bool = True) -> Tuple:
     """Loads the checkpoint file.
     If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
+    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
+    and its descendants.
     If finetune is True, then only the weights and buffers of model should be reload.
-    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
+    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s
+    state_dict() function.
 
     :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
     :type checkpoint_path: str