#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp

from colossalai.registry import HOOKS
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import BaseHook
from colossalai.utils import is_dp_rank_0
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint

from ._lr_scheduler_hook import LRSchedulerHook


@HOOKS.register_module
class SaveCheckpointHook(BaseHook):
    """Saves the model checkpoint at a fixed epoch interval during training.

    :param trainer: Trainer attached with current hook
    :param interval: Number of epochs between two checkpoints
    :param checkpoint_dir: Directory in which checkpoints are saved
    :param suffix: Suffix appended to the checkpoint file name
    :param priority: Priority in the printing; hooks with a smaller priority are printed in front
    :type trainer: Trainer
    :type interval: int, optional
    :type checkpoint_dir: str, optional
    :type suffix: str, optional
    :type priority: int, optional
    """

    def __init__(self,
                 trainer: Trainer,
                 interval: int = 1,
                 checkpoint_dir: str = None,
                 suffix: str = '',
                 priority: int = 10):
        super().__init__(trainer=trainer, priority=priority)
        assert isinstance(trainer, Trainer), \
            f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
        self.interval = interval
        self.checkpoint_dir = checkpoint_dir
        self.suffix = suffix

        # get lr scheduler from the LRSchedulerHook before train
        self._lr_scheduler = None

    def before_train(self):
        # check if an lr scheduler is present in an attached LRSchedulerHook
        for hook in self.trainer.hooks:
            if isinstance(hook, LRSchedulerHook):
                self._lr_scheduler = hook.lr_scheduler
                break

    def after_train_epoch(self):
        """Saves the model after a training epoch.
        """
        # save by interval
        if self.trainer.cur_epoch % self.interval == 0:
            # only gpus with data parallel rank equal to 0 write to the disk
            if is_dp_rank_0():
                save_path = get_checkpoint_path(self.checkpoint_dir,
                                                self.trainer.cur_epoch,
                                                suffix=self.suffix)

                save_checkpoint(save_path,
                                self.trainer.cur_epoch,
                                self.trainer.engine.model,
                                self.trainer.engine.optimizer,
                                self._lr_scheduler)
                self.logger.info(
                    f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')


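# Hedged usage sketch (not part of the original file): shows how a
# SaveCheckpointHook could be constructed for an existing Trainer. The keyword
# arguments mirror the __init__ signature above; `trainer`, the checkpoint
# directory and the interval are illustrative and assumed to be configured
# elsewhere in the training script.
def _example_save_hook(trainer: Trainer) -> SaveCheckpointHook:
    return SaveCheckpointHook(trainer=trainer,
                              interval=5,                      # save every 5 epochs
                              checkpoint_dir='./checkpoints',  # illustrative path
                              suffix='_demo')

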
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
    """Loads a model checkpoint before the training process starts.

    :param trainer: Trainer attached with current hook
    :param checkpoint_dir: Directory from which the checkpoint is loaded
    :param epoch: Epoch number of the checkpoint to load; -1 loads the latest checkpoint
    :param finetune: Whether to allow loading only part of the model; if True, training restarts from epoch 0
    :param strict: Whether to strictly enforce that the loaded parameters match the model's parameter shapes and names
    :param suffix: Suffix of the checkpoint file name
    :param priority: Priority in the printing; hooks with a smaller priority are printed in front
    :type trainer: Trainer
    :type checkpoint_dir: str, optional
    :type epoch: int, optional
    :type finetune: bool, optional
    :type strict: bool, optional
    :type suffix: str, optional
    :type priority: int, optional
    """

    def __init__(self,
                 trainer: Trainer = None,
                 checkpoint_dir: str = None,
                 epoch: int = -1,
                 finetune: bool = False,
                 strict: bool = False,
                 suffix: str = '',
                 priority: int = 0) -> None:
        super().__init__(trainer=trainer, priority=priority)
        assert isinstance(trainer, Trainer), \
            f'LoadCheckpointHook expects a Trainer, got {type(trainer)}'
        self.epoch = epoch
        self.checkpoint_dir = checkpoint_dir
        self.finetune = finetune
        self.suffix = suffix
        self.strict = strict

    def before_train(self):
        """Loads parameters to the model before training.
        """
        # check if an lr scheduler is present in an attached LRSchedulerHook
        lr_scheduler = None
        for hook in self.trainer.hooks:
            if isinstance(hook, LRSchedulerHook):
                lr_scheduler = hook.lr_scheduler
                break

        # use the latest checkpoint if epoch == -1
        if self.epoch == -1:
            path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
        else:
            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)

        if osp.exists(path):
            last_epoch, _ = load_checkpoint(path,
                                            self.trainer.engine.model,
                                            self.trainer.engine.optimizer,
                                            lr_scheduler,
                                            finetune=self.finetune,
                                            strict=self.strict)
            if self.finetune:
                # fine-tuning starts counting epochs from scratch
                self.trainer.cur_epoch = 0
            else:
                # resume from the epoch recorded in the checkpoint
                self.trainer.cur_epoch = last_epoch

            self.logger.info(
                f'loaded checkpoint from {path}')
        else:
            raise FileNotFoundError(f'checkpoint is not found at {path}')
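

# Hedged usage sketch (not part of the original file): shows how a
# LoadCheckpointHook could be constructed to resume training. The keyword
# arguments mirror the __init__ signature above; `trainer` and the checkpoint
# directory are illustrative. epoch=-1 resumes from the latest checkpoint,
# while finetune=True would restart the epoch counter from 0 after loading.
def _example_load_hook(trainer: Trainer) -> LoadCheckpointHook:
    return LoadCheckpointHook(trainer=trainer,
                              checkpoint_dir='./checkpoints',  # illustrative path
                              epoch=-1,                        # -1 -> latest checkpoint
                              finetune=False,
                              strict=False)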