mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish _checkpoint_hook.py code style (#1722)
parent
b38efe4e8a
commit
730f88f8e1
|
@ -50,32 +50,23 @@ class SaveCheckpointHook(BaseHook):
|
|||
break
|
||||
self.model = self.model if self.model is not None else trainer.engine.model
|
||||
|
||||
|
||||
def after_train_iter(self, trainer, output, label, loss):
|
||||
"""Saves the model after a training iter.
|
||||
"""
|
||||
# save by interval
|
||||
if self.save_by_iter and trainer.cur_step % self.interval == 0:
|
||||
save_checkpoint(self.checkpoint_dir,
|
||||
trainer.cur_epoch,
|
||||
self.model,
|
||||
trainer.engine.optimizer,
|
||||
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
|
||||
self._lr_scheduler)
|
||||
self.logger.info(
|
||||
f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', ranks=[0])
|
||||
self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
|
||||
ranks=[0])
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Saves the model after a training epoch.
|
||||
"""
|
||||
# save by interval
|
||||
if trainer.cur_epoch % self.interval == 0:
|
||||
save_checkpoint(self.checkpoint_dir,
|
||||
trainer.cur_epoch,
|
||||
self.model,
|
||||
trainer.engine.optimizer,
|
||||
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
|
||||
self._lr_scheduler)
|
||||
self.logger.info(
|
||||
f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
|
||||
self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
|
||||
|
|
Loading…
Reference in New Issue