|
|
|
@ -3,7 +3,7 @@ import torch.distributed as dist
|
|
|
|
|
from colossalai.tensor import ColoTensor |
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer |
|
|
|
|
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor |
|
|
|
|
from typing import Optional |
|
|
|
|
from typing import Optional, Dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(path: str, |
|
|
|
@ -71,22 +71,23 @@ def save_checkpoint(path: str,
|
|
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(path, |
|
|
|
|
def load_checkpoint(path: str, |
|
|
|
|
epoch: int, |
|
|
|
|
model: torch.nn.Module, |
|
|
|
|
optimizer: Optional[ColossalaiOptimizer] = None, |
|
|
|
|
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, |
|
|
|
|
*args, |
|
|
|
|
**kwargs): |
|
|
|
|
torch_load_kwargs: Optional[Dict] = None, |
|
|
|
|
load_state_dict_kwargs: Optional[Dict] = None): |
|
|
|
|
"""load_checkpoint |
|
|
|
|
load a model, whose parameters are `ColoTensor`s. |
|
|
|
|
Args: |
|
|
|
|
path (_type_): _description_ |
|
|
|
|
epoch (int): _description_ |
|
|
|
|
rank (int): _description_ |
|
|
|
|
model (torch.nn.Module): _description_ |
|
|
|
|
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None. |
|
|
|
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. |
|
|
|
|
path (str): directory to save the checkpoint files. |
|
|
|
|
epoch (int): the number of epoch |
|
|
|
|
model (torch.nn.Module): a torch module initialized by ColoInitContext |
|
|
|
|
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None. |
|
|
|
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None. |
|
|
|
|
torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function |
|
|
|
|
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function |
|
|
|
|
""" |
|
|
|
|
rank = dist.get_rank() |
|
|
|
|
mapping = dict() |
|
|
|
@ -96,8 +97,8 @@ def load_checkpoint(path,
|
|
|
|
|
gather_tensor(p) |
|
|
|
|
|
|
|
|
|
if rank == 0: |
|
|
|
|
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) |
|
|
|
|
model.load_state_dict(load_state['model']) |
|
|
|
|
load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) |
|
|
|
|
model.load_state_dict(load_state['model'], **load_state_dict_kwargs) |
|
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
# scatter loaded parameters |
|
|
|
@ -118,8 +119,8 @@ def load_checkpoint(path,
|
|
|
|
|
gather_tensor(t) |
|
|
|
|
|
|
|
|
|
if rank == 0: |
|
|
|
|
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) |
|
|
|
|
optimizer.load_state_dict(colo_checkpoint['optim']) |
|
|
|
|
colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) |
|
|
|
|
optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) |
|
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
for k, v in optimizer.state_dict()['state'].items(): |
|
|
|
|