fixed mkdir conflict and align yapf config with flake (#220)

pull/232/head
Frank Lee 2022-02-14 16:19:24 +08:00
parent 65e72983dc
commit 3a1a9820b0
4 changed files with 27 additions and 25 deletions

View File

@@ -3,7 +3,7 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: ['--style=google', '--parallel', '--in-place']
+        args: ['--style=.style.yapf', '--parallel', '--in-place']
   - repo: https://github.com/pycqa/flake8
     rev: '4.0.1'
     hooks:

.style.yapf (new file, 5 lines)
View File

@@ -0,0 +1,5 @@
+[style]
+based_on_style = google
+spaces_before_comment = 4
+split_before_logical_operator = true
+column_limit = 120
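
For reference, a small sketch (not part of the commit; it assumes the yapf 0.32 API pinned by the hook above, where FormatCode returns a (code, changed) tuple) showing that a manual run and the pre-commit hook's --style=.style.yapf resolve to the same settings:

from yapf.yapflib.yapf_api import FormatCode

# hypothetical unformatted snippet, reformatted with the new repo style file
source = "def f(a,b):\n    return {'sum':a+b,'diff':a-b}\n"
formatted, changed = FormatCode(source, style_config='.style.yapf')
print(formatted)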

View File

@@ -8,7 +8,6 @@ from typing import Union
 from colossalai.context.parallel_mode import ParallelMode

 _FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=_FORMAT)
@@ -39,7 +38,8 @@ class DistributedLogger:
     def __init__(self, name):
         if name in DistributedLogger.__instances:
-            raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
+            raise Exception(
+                'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
         else:
             self._name = name
             self._logger = logging.getLogger(name)
@@ -58,11 +58,7 @@ class DistributedLogger:
         self._check_valid_logging_level(level)
         self._logger.setLevel(getattr(logging, level))

-    def log_to_file(self,
-                    path: Union[str, Path],
-                    mode: str = 'a',
-                    level: str = 'INFO',
-                    suffix: str = None):
+    def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
         """Save the logs to file

         :param path: The file to save the log
@@ -77,9 +73,13 @@ class DistributedLogger:
         assert isinstance(path, (str, Path)), \
             f'expected argument path to be type str or Path, but got {type(path)}'
         self._check_valid_logging_level(level)

         if isinstance(path, str):
             path = Path(path)

+        # create log directory
+        path.mkdir(parents=True, exist_ok=True)
+
         # set the default file name if path is a directory
         if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
             rank = 0
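
A minimal usage sketch (not part of the diff; it assumes get_dist_logger() can be called with its defaults): with the mkdir(parents=True, exist_ok=True) call added above, passing a log directory that does not exist yet no longer fails, and ranks creating the same folder concurrently no longer conflict.

from colossalai.logging import get_dist_logger

# hypothetical directory; log_to_file now creates it if missing
logger = get_dist_logger()
logger.log_to_file('./logs/exp1', mode='a', level='INFO')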

View File

@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import re
 from typing import Tuple
+from pathlib import Path

 import torch
@@ -10,10 +11,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc

 __all__ = [
-    'get_checkpoint_path',
-    'get_latest_checkpoint_path',
-    'get_latest_checkpoint_pattern',
-    'save_checkpoint',
+    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
     'load_checkpoint'
 ]
@@ -70,9 +68,9 @@ def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
 def _ensure_directory_exists(filename: str):
     # ensure the directory exists
-    dir = os.path.dirname(filename)
-    if not os.path.exists(dir):
-        os.makedirs(dir)
+    dirpath = os.path.dirname(filename)
+    if not os.path.exists(dirpath):
+        Path(dirpath).mkdir(parents=True, exist_ok=True)


 def get_latest_checkpoint_pattern(suffix: str = ''):
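
This is the mkdir conflict from the commit title: the old check-then-create pattern (os.path.exists followed by os.makedirs) can raise FileExistsError when several ranks race to create the same directory. A short sketch of the idempotent replacement, with a hypothetical path:

from pathlib import Path

# exist_ok=True makes the call safe to repeat from many processes at once;
# a bare os.makedirs(dirpath) would raise FileExistsError if another rank
# created the directory between the existence check and the call.
Path('./checkpoints/run1').mkdir(parents=True, exist_ok=True)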
@@ -84,7 +82,8 @@ def get_latest_checkpoint_pattern(suffix: str = ''):
     :rtype: regular expression
     """
     ranks_name = _get_ranks_name()
-    ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
+    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
+    ckpt_pattern = re.compile(pattern)
     return ckpt_pattern
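
The rewritten pattern also avoids the deprecated escape sequences that flake8 flags: '\d' and '\.' inside a plain or f-string literal are invalid string escapes, while a raw string passes them to the regex engine unchanged. A small sketch with hypothetical values (the real rank name comes from _get_ranks_name()):

import re

ranks_name = 'rank_0'   # hypothetical
suffix = ''
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
assert re.compile(pattern).match('epoch12-rank_0.pt') is not None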
@@ -127,7 +126,8 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model,
+    optimizer, lr_scheduler and etc. into a checkpoint dictionary.

     This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
@@ -150,12 +150,7 @@ def save_checkpoint(checkpoint_path: str,
     model_sd = model.state_dict()

     # ckpt container
-    checkpoint = {
-        'epoch': epoch,
-        'model': model_sd,
-        'optimizer': optimizer.state_dict(),
-        **kwargs
-    }
+    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}

     if lr_scheduler is not None:
         checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
@@ -171,9 +166,11 @@ def load_checkpoint(checkpoint_path: str,
                     strict: bool = True) -> Tuple:
     """Loads the checkpoint file.

     If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
+    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
+    and its descendants.

     If finetune is True, then only the weights and buffers of model should be reload.
-    If strict is True, then the keys of state_dict must exactly match the keys returned by this modules state_dict() function.
+    If strict is True, then the keys of state_dict must exactly match the keys returned by this modules
+    state_dict() function.

     :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
     :type checkpoint_path: str