From 846329064207862801045abbbe8d38e3b5169a63 Mon Sep 17 00:00:00 2001 From: HELSON Date: Tue, 26 Jul 2022 14:41:53 +0800 Subject: [PATCH] [checkpoint] use args, kwargs in save_checkpoint, load_checkpoint (#1368) --- colossalai/utils/checkpoint/module_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py index c59d1ecf3..cf9b11cc6 100644 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -39,7 +39,7 @@ def save_checkpoint(dire: str, delattr(v, 'save_ready') # model saving save_state = {'epoch': epoch, 'model': model_state} - torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch)) + torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) # delete old dicts del model_state @@ -57,7 +57,7 @@ def save_checkpoint(dire: str, if rank == 0: save_state = {'epoch': epoch, 'optim': optim_state} - torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch)) + torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) # recover colo tensors in rank0 for k, v in optimizer.state_dict()['state'].items(): for n, t in v.items(): @@ -96,7 +96,7 @@ def load_checkpoint(dire, gather_tensor(p) if rank == 0: - load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch)) + load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) model.load_state_dict(load_state['model']) dist.barrier() @@ -118,7 +118,7 @@ def load_checkpoint(dire, gather_tensor(t) if rank == 0: - colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch)) + colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) optimizer.load_state_dict(colo_checkpoint['optim']) dist.barrier()