mirror of https://github.com/hpcaitech/ColossalAI

[checkpoint] add ColoOptimizer checkpointing (#1316)

parent 7c2634f4b3
commit 9e4c6449b0
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -1,12 +1,15 @@
 import torch
 import torch.distributed as dist
 from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from copy import copy
+from typing import Optional


 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
         dire (str): directory to save the checkpoint files.
         epoch (int): the number of epoch
         model (torch.nn.Module): a torch module initialized by ColoInitContext
-        optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
     # delete the new dict
     del new_dict

+    optim_state_copy = copy(optimizer.state_dict())
+    for k, v in optim_state_copy['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                t.to_replicate_()
+    if dist.get_rank() == 0:
+        model_state = {'epoch': epoch, 'optim': optim_state_copy}
+        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+    del optim_state_copy


 def load_checkpoint(dire,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):
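Note on the hunk above: copy() from the copy module is shallow, so the nested 'state' dict and the ColoTensors inside optim_state_copy are the very objects the live optimizer holds, and to_replicate_() therefore mutates the optimizer's own state as well. A minimal sketch of that Python behavior (toy dict, no ColossalAI needed):

    from copy import copy

    d = {'state': {0: {'step': 1}}}
    d2 = copy(d)                      # shallow: only the outer dict is new
    assert d2['state'] is d['state']  # nested dict is shared, not duplicated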
@@ -56,7 +69,7 @@ def load_checkpoint(dire,
         epoch (int): _description_
         rank (int): _description_
         model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
@@ -74,3 +87,24 @@ def load_checkpoint(dire,
     for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             v.set_tensor_spec(*mapping[k])
+
+    del mapping
+    mapping = dict()
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
+                t.to_replicate_()
+
+    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+    optimizer.load_state_dict(colo_checkpoint['optim'])
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                # Skip keys not in mapping:
+                # for Adam, until step() has run once, exp_avg and exp_avg_sq do not exist in the optimizer state.
+                if (k, n) not in mapping:
+                    continue
+                t.set_tensor_spec(*mapping[(k, n)])
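For reference, the nested layout the loops above walk is the standard torch optimizer state dict, and for Adam the exp_avg/exp_avg_sq entries appear only after the first step() — which is exactly why the skip above is needed. A standalone illustration in plain torch (no ColossalAI required):

    import torch

    p = torch.nn.Parameter(torch.zeros(2))
    opt = torch.optim.Adam([p], lr=0.1)
    print(opt.state_dict()['state'])            # {} -- empty before any step()
    p.grad = torch.ones(2)
    opt.step()
    print(opt.state_dict()['state'][0].keys())  # step, exp_avg, exp_avg_sq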
@@ -77,6 +77,18 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))


+def compare_optims(optim1, optim2):
+    state1 = optim1.state_dict()['state']
+    state2 = optim2.state_dict()['state']
+    for k, p1 in state1.items():
+        if k not in state2:
+            continue
+        p2 = state2[k]
+        if isinstance(p1, ColoTensor):
+            assert isinstance(p2, ColoTensor)
+            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+
+
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     model_reload = model_reload.cuda()
     model_reload.train()

-    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
+    opt_class = torch.optim.Adam
+    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
+    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
+    run_reload = False

     for i, (data, label) in enumerate(train_dataloader):
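Aside: the removed line above had two bugs that this hunk also fixes — named_parameters() yields (name, Parameter) tuples, which torch optimizers reject, and SGD takes lr, not r. A quick demonstration in plain torch:

    import torch

    m = torch.nn.Linear(2, 2)
    next(iter(m.named_parameters()))          # ('weight', Parameter(...)) -- a tuple
    torch.optim.SGD(m.parameters(), lr=0.1)   # the corrected form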
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
-            output_reload = model_reload(data)
             loss = criterion(output, label)
-            loss_reload = criterion(output_reload, label)
         else:
-            output = model(data, label)
-            loss = output
-            loss_reload = model_reload(data, label)
+            loss = model(data, label)

         loss.backward()
-        loss_reload.backward()
+        colo_optimizer.step()
+
+        if run_reload:
+            colo_optimizer_reload.zero_grad()
+            if criterion:
+                output_reload = model_reload(data)
+                loss_reload = criterion(output_reload, label)
+            else:
+                loss_reload = model_reload(data, label)
+            loss_reload.backward()
+            colo_optimizer_reload.step()

         if i > 2:
             break

     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-    save_checkpoint('./checkpoint', 0, model, None, None)
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, None, None)
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
+    dist.barrier()

     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
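The barrier placement in the hunk above matters: every rank replicates its shards inside save_checkpoint, but only rank 0 writes the file, so the other ranks must wait at a barrier before load_checkpoint reads it back; the second barrier presumably keeps ranks in step before the states are compared. The minimal calling pattern, exactly as exercised here:

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    dist.barrier()   # rank 0 must finish writing before any rank loads
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
    dist.barrier()   # keep ranks in step before comparing states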
@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
         p.to_replicate_()

     check_param_equal(model, model_reload)
-
+    compare_optims(colo_optimizer, colo_optimizer_reload)
     if rank == 0:
         remove('./checkpoint')
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['bert', 'simple_net']:
+    for model_name in ['simple_net', 'bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
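For context, a per-rank entry point like run_dist is typically driven by spawning one process per rank; a hypothetical launcher sketch (the real test file presumably uses pytest plus a free-port helper; world_size and port here are illustrative, not taken from this diff):

    from functools import partial
    import torch.multiprocessing as mp

    def test_checkpoint(world_size=2):
        # mp.spawn calls run_fn(rank) for each rank in [0, world_size)
        run_fn = partial(run_dist, world_size=world_size, port=29500,
                         use_ddp=False, use_mp_reload=False, test_scheduler=None)
        mp.spawn(run_fn, nprocs=world_size)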