[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)

pull/1340/head
HELSON 2022-07-19 14:15:28 +08:00 committed by GitHub
parent f3ce7b8336
commit f92c100ddd
6 changed files with 209 additions and 91 deletions


@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)
 
-    def size_global(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None) -> torch.Size:
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
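Aside (not part of the diff): the `-> torch.Size` annotation documents that `size_global()` always reports the full, un-sharded shape. The new `scatter_tensor` helper added later in this commit relies on exactly that to allocate a receive buffer on non-zero ranks. A minimal sketch, assuming `colo_tensor` is any existing ColoTensor (the variable name is illustrative):

    # size_global() reports the full (un-sharded) shape even for a sharded tensor,
    # so a non-zero rank can allocate a receive buffer of the right size.
    global_size = colo_tensor.size_global()
    recv_buf = torch.empty(global_size, device=colo_tensor.device)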


@@ -141,9 +141,18 @@ class ProcessGroup:
     def rank(self):
         return self._rank
 
+    def ranks_in_group(self):
+        return self._rank_list
+
     def world_size(self):
         return self._world_size
 
+    def tp_rank_list(self):
+        return self._tp_rank_list
+
+    def dp_rank_list(self):
+        return self._dp_rank_list
+
     def tp_local_rank(self):
         return self._rank % self._tp_degree
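Aside (not part of the diff): `tp_rank_list()` is what lets the new `gather_tensor` helper decide whether the current tensor-parallel group contains global rank 0, i.e. whether it has to replicate its shards before saving. A sketch of that check, using only the accessors added here (`colo_tensor` is illustrative):

    pg = colo_tensor.get_process_group()
    # Only the tensor-parallel group that owns global rank 0 must materialize a
    # replicated copy; every other TP group keeps its shards untouched.
    needs_gather = (pg.tp_rank_list()[0] == 0)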


@ -1,8 +1,8 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.tensor import ColoTensor, DistSpecManager from colossalai.tensor import ColoTensor
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from copy import copy from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from typing import Optional from typing import Optional
@@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
         optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
+    rank = dist.get_rank()
+    model_state = model.state_dict()
+    # save the dist context about the tensors in a new dict, while still maintain the original dict.
+    for k, v in model_state.items():
+        if isinstance(v, ColoTensor):
+            gather_tensor(v)    # gather shared tensors to rank0
+            # don't recover tensors in rank0, since the dict is only a copy of model
+
+    if rank == 0:
+        # sanity check
+        for k, v in model_state.items():
+            if isinstance(v, ColoTensor):
+                assert v.save_ready
+                assert v.is_replicate()
+                delattr(v, 'save_ready')
+        # model saving
+        save_state = {'epoch': epoch, 'model': model_state}
+        torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
+
+    # delete old dicts
+    del model_state
+    # synchronize all the processes
+    dist.barrier()
+
     mapping = dict()
-    new_dict = dict()
-
-    # save the dist context about the tensors in a new dict, while still maintain the original dict.
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            new_dict[k] = v.to_replicate().detach()
-        else:
-            new_dict[k] = v
-
-    if dist.get_rank() == 0:
-        for k, v in new_dict.items():
-            if isinstance(v, ColoTensor):
-                assert v.is_replicate()
-        model_state = {'epoch': epoch, 'model': new_dict}
-        torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-    # delete the new dict
-    del new_dict
-
-    optim_state_copy = copy(optimizer.state_dict())
-    for k, v in optim_state_copy['state'].items():
+    optim_state = optimizer.state_dict()
+    for k, v in optim_state['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                t.to_replicate_()
-    if dist.get_rank() == 0:
-        model_state = {'epoch': epoch, 'optim': optim_state_copy}
-        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
-    del optim_state_copy
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
+
+    if rank == 0:
+        save_state = {'epoch': epoch, 'optim': optim_state}
+        torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+
+    # recover colo tensors in rank0
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                assert hasattr(t, 'save_ready')
+                t.set_dist_spec(mapping[(k, n)])
+                delattr(t, 'save_ready')
+
+    del optim_state
+    del mapping
+    dist.barrier()
 
 
 def load_checkpoint(dire,
@@ -72,39 +87,42 @@ def load_checkpoint(dire,
         optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
+    rank = dist.get_rank()
     mapping = dict()
-    for k, v in model.state_dict().items():
-        if isinstance(v, ColoTensor):
-            mapping[k] = (v.dist_spec, v.compute_spec)
-            v.to_replicate_()
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            mapping[n] = p.dist_spec
+            gather_tensor(p)
 
-    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
-    model.load_state_dict(model_state['model'])
+    if rank == 0:
+        load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
+        model.load_state_dict(load_state['model'])
+    dist.barrier()
 
-    # reset tensors to original dist spec.
-    with DistSpecManager.no_grad():
-        for k, v in model.state_dict().items():
-            if isinstance(v, ColoTensor):
-                v.set_tensor_spec(*mapping[k])
+    # scatter loaded parameters
+    for n, p in model.named_parameters():
+        if isinstance(p, ColoTensor):
+            scatter_tensor(p, mapping[n])
+            if rank == 0:
+                assert hasattr(p, 'save_ready')
+                delattr(p, 'save_ready')
 
     del mapping
+
     mapping = dict()
     for k, v in optimizer.state_dict()['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
-                t.to_replicate_()
+                mapping[(k, n)] = t.dist_spec
+                gather_tensor(t)
 
-    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
-    optimizer.load_state_dict(colo_checkpoint['optim'])
+    if rank == 0:
+        colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+        optimizer.load_state_dict(colo_checkpoint['optim'])
+    dist.barrier()
 
     for k, v in optimizer.state_dict()['state'].items():
         for n, t in v.items():
             if isinstance(t, ColoTensor):
-                # skip key not in mapping.
-                # For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer
-                if (k, n) not in mapping:
-                    continue
-                t.set_tensor_spec(*mapping[(k, n)])
+                scatter_tensor(t, mapping[(k, n)])
+
+    del mapping
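Aside (not part of the diff): the call pattern is unchanged for users, but both functions are now collective, so every rank must enter them even though only rank 0 touches the checkpoint files. A usage sketch mirroring the updated unit test below (variable names are the test's):

    # All ranks participate; rank 0 alone reads/writes the .pth files.
    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    dist.barrier()
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)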


@@ -0,0 +1,50 @@
import torch
import torch.distributed as dist

from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.distspec import _DistSpec


def gather_tensor(colo_tensor: ColoTensor) -> None:
    """Make colo_tensor replicated when the rank is 0
    """
    if not colo_tensor.is_replicate():
        pg = colo_tensor.get_process_group()
        # for the group which contains rank 0
        if pg.tp_rank_list()[0] == 0:
            old_dist_spec = colo_tensor.dist_spec
            colo_tensor.to_replicate_()
            if dist.get_rank() != 0:
                colo_tensor.set_dist_spec(old_dist_spec)

        # synchronize all processes for unexpected problems
        dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signitrue


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    """Reversal operation of `gather_tensor`.
    """
    if dist_spec.placement == 'r':
        dist.broadcast(colo_tensor.data, 0)
    else:
        global_size = colo_tensor.size_global()

        if dist.get_rank() == 0:
            entire_data = colo_tensor.data
        else:
            entire_data = torch.empty(global_size, device=colo_tensor.device)
        dist.broadcast(entire_data, 0)

        if dist.get_rank() == 0:
            colo_tensor.set_dist_spec(dist_spec)
        else:
            rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
                pg=colo_tensor.get_process_group(),
                compute_attr=colo_tensor.compute_spec))
            rep_tensor.set_dist_spec(dist_spec)
            with torch.no_grad():
                colo_tensor.data.copy_(rep_tensor.data)

    # synchronize all processes for unexpected problems
    dist.barrier()
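Aside (not part of the diff): the intended round trip is gather, save on rank 0, then scatter back to the original layout; the new unit test at the end of this commit exercises exactly this. A sketch, assuming `param` is a sharded ColoTensor (the file name is illustrative):

    old_spec = param.dist_spec                   # remember the shard layout
    gather_tensor(param)                         # replicated on rank 0 only
    if dist.get_rank() == 0:
        torch.save(param.data, 'param.pth')      # 'param.pth' is an illustrative name
    scatter_tensor(param, old_spec)              # restore the shard layout on every rank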


@@ -1,6 +1,7 @@
 import os, shutil
 import torch
 import pytest
+from copy import deepcopy
 from functools import partial
 
 import torch.multiprocessing as mp
@@ -15,8 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
-from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.optimizer import ColossalaiOptimizer
@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
 def check_param_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
+    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
+        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
 
 
 def remove(path):
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
         if k not in state2:
             continue
         p2 = state2[k]
-        if isinstance(p1, ColoTensor):
-            assert isinstance(p2, ColoTensor)
-            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+        for n, t1 in p1.items():
+            if n not in p2:
+                continue
+            t2 = p2[n]
+            if isinstance(t1, ColoTensor):
+                assert isinstance(t2, ColoTensor)
+                assert torch.allclose(t1, t2, rtol=0, atol=0)
 
 
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     # set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
-        model_reload = model_builder(checkpoint=True)
 
     if use_mp_reload:
         if 'bert' == model_name:
@@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
             elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
             elif p.process_group.tp_world_size() == 1:
-                p.redistribute(ReplicaSpec(), pg)
+                p.set_process_group(pg)
     elif "simple_net" == model_name:
         init_spec_func(model, pg)
 
+    model_reload = deepcopy(model)
     model = model.cuda()
-    model.train()
+    model.eval()
     model_reload = model_reload.cuda()
-    model_reload.train()
+    model_reload.eval()
 
     opt_class = torch.optim.Adam
     colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
     colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
-    run_reload = False
 
     for i, (data, label) in enumerate(train_dataloader):
         # Zero grad
         colo_optimizer.zero_grad()
+        colo_optimizer_reload.zero_grad()
 
         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -155,14 +159,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
         loss.backward()
         loss_reload.backward()
 
-        if run_reload:
-            colo_optimizer_reload.zero_grad()
-            if criterion:
-                output_reload = model_reload(data)
-                loss_reload = criterion(output_reload, label)
-            else:
-                loss_reload = model_reload(data, label)
-            loss_reload.backward()
-            colo_optimizer_reload.step()
+        colo_optimizer.step()
+        colo_optimizer_reload.step()
 
         if i > 2:
@@ -170,28 +167,25 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
+    dist.barrier()
 
     save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
+    dist.barrier()
     load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
+    dist.barrier()
 
-    # Since model is sharded, we merge them before param checking.
-    for p in model.parameters():
-        p.to_replicate_()
-
-    for p in model_reload.parameters():
-        p.to_replicate_()
-
     check_param_equal(model, model_reload)
     compare_optims(colo_optimizer, colo_optimizer_reload)
 
     if rank == 0:
         remove('./checkpoint')
+    dist.barrier()
 
 
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['simple_net', 'bert']:
+    # TODO(haichen) add BERT in the test
+    # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
+    for model_name in ['simple_net']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,


@@ -0,0 +1,47 @@
import torch
import pytest
from functools import partial
import torch.multiprocessing as mp
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
from tests.test_tensor._utils import tensor_shard_equal


def run_dist(rank, world_size, port, dp_degree, tp_degree):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)

    x = torch.randn(4, 4, device=get_current_device())
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
    param.set_tensor_spec(*spec)

    gather_tensor(param)
    if dist.get_rank() == 0:
        assert torch.allclose(x, param.data, rtol=0, atol=0)
    else:
        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    dist.barrier()

    scatter_tensor(param, spec[0])
    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
    assert param.requires_grad is True
    dist.barrier()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_checkpoint(world_size=4)