[checkpoint] add ColoOptimizer checkpointing (#1316)

pull/1251/head^2
Jiarui Fang 2022-07-15 09:52:55 +08:00 committed by GitHub
parent 7c2634f4b3
commit 9e4c6449b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 15 deletions

View File

@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor

View File

@ -1,12 +1,15 @@
import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager
from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy
from typing import Optional
def save_checkpoint(dire: str,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
dire (str): directory to save the checkpoint files.
epoch (int): the number of epoch
model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
# delete the new dict
del new_dict
optim_state_copy = copy(optimizer.state_dict())
for k, v in optim_state_copy['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
t.to_replicate_()
if dist.get_rank() == 0:
model_state = {'epoch': epoch, 'optim': optim_state_copy}
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
del optim_state_copy
def load_checkpoint(dire,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
optimizer: Optional[ColossalaiOptimizer] = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
@ -56,7 +69,7 @@ def load_checkpoint(dire,
epoch (int): _description_
rank (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
@ -74,3 +87,24 @@ def load_checkpoint(dire,
for k, v in model.state_dict().items():
if isinstance(v, ColoTensor):
v.set_tensor_spec(*mapping[k])
del mapping
mapping = dict()
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
t.to_replicate_()
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
optimizer.load_state_dict(colo_checkpoint['optim'])
for k, v in optimizer.state_dict()['state'].items():
for n, t in v.items():
if isinstance(t, ColoTensor):
# skip key not in mapping.
# For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer
if (k, n) not in mapping:
continue
t.set_tensor_spec(*mapping[(k, n)])

View File

@ -77,6 +77,18 @@ def remove(path):
raise ValueError("file {} is not a file or dir.".format(path))
def compare_optims(optim1, optim2):
state1 = optim1.state_dict()['state']
state2 = optim2.state_dict()['state']
for k, p1 in state1.items():
if k not in state2:
continue
p2 = state2[k]
if isinstance(p1, ColoTensor):
assert isinstance(p2, ColoTensor)
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
model_reload = model_reload.cuda()
model_reload.train()
colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
opt_class = torch.optim.Adam
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
run_reload = False
for i, (data, label) in enumerate(train_dataloader):
@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
# Bcast rank0 data to all processes
if criterion:
output = model(data)
output_reload = model_reload(data)
loss = criterion(output, label)
loss_reload = criterion(output_reload, label)
else:
output = model(data, label)
loss = output
loss = model(data, label)
loss_reload = model_reload(data, label)
loss.backward()
colo_optimizer.step()
loss_reload.backward()
if run_reload:
colo_optimizer_reload.zero_grad()
if criterion:
output_reload = model_reload(data)
loss_reload = criterion(output_reload, label)
else:
loss_reload = model_reload(data, label)
loss_reload.backward()
colo_optimizer_reload.step()
if i > 2:
break
if not os.path.isdir('./checkpoint') and rank == 0:
os.mkdir('./checkpoint')
save_checkpoint('./checkpoint', 0, model, None, None)
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
dist.barrier()
load_checkpoint('./checkpoint', 0, model_reload, None, None)
# Since model is sharded, we merge them before param checking.
for p in model.parameters():
@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
p.to_replicate_()
check_param_equal(model, model_reload)
compare_optims(colo_optimizer, colo_optimizer_reload)
if rank == 0:
remove('./checkpoint')
@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
for model_name in ['bert', 'simple_net']:
for model_name in ['simple_net', 'bert']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,