mirror of https://github.com/hpcaitech/ColossalAI
fixed mkdir conflict and align yapf config with flake (#220)
parent 65e72983dc
commit 3a1a9820b0
@@ -3,7 +3,7 @@ repos:
       rev: v0.32.0
       hooks:
           - id: yapf
-            args: ['--style=google', '--parallel', '--in-place']
+            args: ['--style=.style.yapf', '--parallel', '--in-place']
     - repo: https://github.com/pycqa/flake8
       rev: '4.0.1'
       hooks:
@@ -0,0 +1,5 @@
+[style]
+based_on_style = google
+spaces_before_comment = 4
+split_before_logical_operator = true
+column_limit = 120
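The pre-commit yapf hook above now points at this new .style.yapf file instead of the bare google preset, so the hook and manual yapf runs share one 120-column configuration. A minimal sketch of checking a snippet against it through yapf's Python API (assumes yapf is installed and the working directory contains .style.yapf; the sample source is a toy input, not project code):

# Illustrative check of a snippet against the repo's .style.yapf.
# FormatCode returns the reformatted source and a flag telling whether anything changed.
from yapf.yapflib.yapf_api import FormatCode

source = "total = (first_term\n         and second_term)\n"   # toy input, not project code
formatted, changed = FormatCode(source, style_config='.style.yapf')
print(changed)           # True if yapf would rewrite the snippet under this style
print(formatted, end='')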
@@ -8,7 +8,6 @@ from typing import Union

 from colossalai.context.parallel_mode import ParallelMode

-
 _FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=_FORMAT)

@@ -39,7 +38,8 @@ class DistributedLogger:

     def __init__(self, name):
         if name in DistributedLogger.__instances:
-            raise Exception('Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
+            raise Exception(
+                'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
         else:
             self._name = name
             self._logger = logging.getLogger(name)
@@ -58,11 +58,7 @@ class DistributedLogger:
         self._check_valid_logging_level(level)
         self._logger.setLevel(getattr(logging, level))

-    def log_to_file(self,
-                    path: Union[str, Path],
-                    mode: str = 'a',
-                    level: str = 'INFO',
-                    suffix: str = None):
+    def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
         """Save the logs to file

         :param path: The file to save the log
@@ -77,9 +73,13 @@ class DistributedLogger:
         assert isinstance(path, (str, Path)), \
             f'expected argument path to be type str or Path, but got {type(path)}'
         self._check_valid_logging_level(level)

         if isinstance(path, str):
             path = Path(path)
+
+        # create log directory
+        path.mkdir(parents=True, exist_ok=True)
+
         # set the default file name if path is a directory
         if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
             rank = 0
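This added mkdir call, together with the checkpointing change further down, is what the commit title refers to: when several ranks initialise logging at once, a check-then-create sequence can raise FileExistsError as soon as another rank wins the race, while exist_ok=True makes the call idempotent. A minimal sketch of the two patterns (the directory name is illustrative only):

# Illustrative comparison of the racy and race-free ways to create a shared log directory.
import os
from pathlib import Path

log_dir = Path('logs/run0')   # hypothetical directory that every rank tries to create

# Racy: another process may create the directory between the check and the makedirs call,
# in which case os.makedirs raises FileExistsError.
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# Race-free, as used in this commit: exist_ok=True swallows the "already exists" case.
log_dir.mkdir(parents=True, exist_ok=True)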
@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import re
 from typing import Tuple
+from pathlib import Path

 import torch

@@ -10,10 +11,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc

 __all__ = [
-    'get_checkpoint_path',
-    'get_latest_checkpoint_path',
-    'get_latest_checkpoint_pattern',
-    'save_checkpoint',
+    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
     'load_checkpoint'
 ]

@@ -70,9 +68,9 @@ def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):

 def _ensure_directory_exists(filename: str):
     # ensure the directory exists
-    dir = os.path.dirname(filename)
-    if not os.path.exists(dir):
-        os.makedirs(dir)
+    dirpath = os.path.dirname(filename)
+    if not os.path.exists(dirpath):
+        Path(dirpath).mkdir(parents=True, exist_ok=True)


 def get_latest_checkpoint_pattern(suffix: str = ''):
@@ -84,7 +82,8 @@ def get_latest_checkpoint_pattern(suffix: str = ''):
     :rtype: regular expression
     """
     ranks_name = _get_ranks_name()
-    ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
+    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
+    ckpt_pattern = re.compile(pattern)
     return ckpt_pattern

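Rewriting the pattern as a raw string also keeps flake8 happy: \d and \. inside an ordinary f-string literal are flagged as invalid escape sequences (W605), while the raw string built with str.format expresses the same regex cleanly. A small usage sketch (the ranks name below is a made-up placeholder, not the real output of _get_ranks_name):

# Illustrative match against a checkpoint file name using the new raw-string pattern.
import re

ranks_name = 'global_rank_0'   # placeholder; the real value comes from _get_ranks_name()
suffix = ''
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
match = re.compile(pattern).match('epoch10-global_rank_0.pt')
print(match.group(1))   # -> 10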
@@ -127,7 +126,8 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model,
+    optimizer, lr_scheduler and etc. into a checkpoint dictionary.

     This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.

@@ -150,12 +150,7 @@ def save_checkpoint(checkpoint_path: str,
     model_sd = model.state_dict()

     # ckpt container
-    checkpoint = {
-        'epoch': epoch,
-        'model': model_sd,
-        'optimizer': optimizer.state_dict(),
-        **kwargs
-    }
+    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
     if lr_scheduler is not None:
         checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

@@ -171,9 +166,11 @@ def load_checkpoint(checkpoint_path: str,
                     strict: bool = True) -> Tuple:
     """Loads the checkpoint file.
     If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
+    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
+    and its descendants.
     If finetune is True, then only the weights and buffers of model should be reload.
-    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
+    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s
+    state_dict() function.

     :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
     :type checkpoint_path: str