[tensor] distributed checkpointing for parameters (#1240)

pull/1271/head
Jiarui Fang 2022-07-12 15:51:06 +08:00 committed by GitHub
parent 49114d8df0
commit c92f84fcdb
6 changed files with 72 additions and 155 deletions

View File

@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
         self._redistribute(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec:
+        if dist_spec is not None:
             assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
             self.set_dist_spec(dist_spec)
-        if compute_spec:
+        if compute_spec is not None:
             self.compute_spec = compute_spec
 
     def has_compute_pattern(self, compute_pattern):
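Note on the `is not None` change: a spec argument can be a perfectly valid object that still evaluates falsy (for instance, a container-like spec with zero entries), so a truthiness test can silently skip it while an explicit None check applies it. A minimal sketch of the pitfall; the `FakeSpec` class is hypothetical, purely for illustration:

    # FakeSpec is hypothetical; it only demonstrates the truthiness pitfall.
    class FakeSpec:
        def __init__(self, placements):
            self.placements = placements

        def __len__(self):
            # container-like objects with len() == 0 evaluate falsy
            return len(self.placements)

    spec = FakeSpec([])

    if spec:                  # falsy -> a valid spec would be silently dropped
        print("truthy check applied the spec")

    if spec is not None:      # robust -> the spec is applied
        print("None check applied the spec")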

View File

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 
 __all__ = ['replicate', 'shard']

View File

@ -1,19 +1,6 @@
import torch import torch
import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
import collections from colossalai.tensor import ColoTensor, DistSpecManager
import inspect
from colossalai.utils.model.colo_init_context import colo_state_dict
def filter_dict(dict_to_filter, thing_with_kwargs):
sig = inspect.signature(thing_with_kwargs)
filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
filter_dict = {}
for filter_key in filter_keys:
if filter_key in dict_to_filter:
filter_dict[filter_key] = dict_to_filter[filter_key]
return filter_dict
def save_checkpoint(dire: str, def save_checkpoint(dire: str,
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': model.state_dict()}
+    mapping = dict()
+    new_dict = dict()
+
+    # save the dist context about the tensors in a new dict, while still maintain the original dict.
+    for k, v in model.state_dict().items():
+        if isinstance(v, ColoTensor):
+            mapping[k] = (v.dist_spec, v.compute_spec)
+            new_dict[k] = v.to_replicate().detach()
+
     if dist.get_rank() == 0:
+        for k, v in new_dict.items():
+            if isinstance(v, ColoTensor):
+                assert v.is_replicate()
+        model_state = {'epoch': epoch, 'model': new_dict}
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-    # TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
-    # 1. convert SHARD ColoTensor to REPLICATE
-    # only rank 0 saves the REPLICATE tensors.
-    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
-    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
+
+    # delete the new dict
+    del new_dict
 
 
 def load_checkpoint(dire,
                     epoch: int,
-                    rank: int,
                     model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
         optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
+    mapping = dict()
+    for k, v in model.named_parameters():
+        if isinstance(v, ColoTensor):
+            mapping[k] = (v.dist_spec, v.compute_spec)
+            v.to_replicate_()
+
     model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
-    model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
     model.load_state_dict(model_state['model'])
-    optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
-    optimizer.load_state_dict(optim_state['optimizer'])
-    lr_scheduler_dict = optim_state['lr_scheduler']
-    if 'after_scheduler_type' in lr_scheduler_dict:
-        after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type')
-        after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict')
-        reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type)
-        filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler)
-        lr_scheduler_dict['after_scheduler'] = reload_scheduler(
-            optimizer,
-            **filtered_dict,
-        )
-    lr_scheduler.load_state_dict(lr_scheduler_dict)
+
+    # reset tensors to original dist spec.
+    with DistSpecManager.no_grad():
+        for k, v in model.named_parameters():
+            if isinstance(v, ColoTensor):
+                v.set_tensor_spec(*mapping[k])
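Taken together, the new flow is: save_checkpoint gathers every sharded ColoTensor into a replicated copy inside a temporary dict (only rank 0 writes the file), and load_checkpoint replicates parameters in place, loads the rank-0 file, then restores each tensor's original (dist_spec, compute_spec) under DistSpecManager.no_grad(). A hedged usage sketch; the import path and the barrier placement are my assumptions, not part of this diff:

    # Hedged sketch: intended per-rank call pattern for the reworked API.
    # Assumes a distributed context is already launched (e.g. via
    # colossalai.launch) and that this module is importable as
    # colossalai.utils.checkpoint.
    import torch
    import torch.distributed as dist
    from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

    def roundtrip(model: torch.nn.Module, epoch: int = 0) -> None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Every rank calls this; internally each ColoTensor is replicated
        # and only rank 0 writes epoch_{epoch}_model.pth.
        save_checkpoint('./checkpoint', epoch, model, optimizer, None)
        dist.barrier()    # ensure the file exists before any rank loads it

        # Every rank loads the same replicated file; parameters are then
        # re-sharded back to their original dist spec.
        load_checkpoint('./checkpoint', epoch, model, optimizer, None)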

View File

@@ -1,13 +1,10 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
+from colossalai.tensor import ColoTensor, ColoParameter
 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding
-from copy import copy
 from torch import nn
 from typing import Iterator, Tuple, Union
-from functools import partialmethod
 
 
 # find named_params includes replica
@@ -34,47 +31,6 @@ def ColoModulize(module):
     module._colo_visited = True
 
 
-def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
-    # build param to spec mapping
-    mapping1 = dict()
-    mapping2 = dict()
-    mapping3 = dict()
-    # gather all params
-    has_dist_parameter = False
-    with torch.no_grad():
-        for param in self.parameters():
-            if isinstance(param, ColoParameter):
-                has_dist_parameter = True
-                mapping1[id(param)] = copy(param.dist_spec)
-                mapping2[id(param)] = copy(param.compute_spec)
-                # TODO(jiaruifang) fixme, we should elegently handle the default PG in init context
-                if param.get_process_group() is None:
-                    param.process_group = ProcessGroup()
-                param.set_dist_spec(distspec.replicate())
-                mapping3[id(param)] = param.get_process_group()
-                param.process_group = None
-
-    # TODO: fix when keep_vars = True
-    # when keep_vars = False, the state_dict_func will call detach to create
-    # new tensors, but when keep_vars = True, the recovery of spec will be reflected
-    # in the `ret`, such that the final state dict will still contain process group,
-    # raising exception as it is not serializable
-    assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
-
-    ret = state_dict_func(self, destination, prefix, keep_vars)
-
-    # recover
-    with torch.no_grad():
-        for param in self.parameters():
-            param_id = id(param)
-            if param_id in mapping1:
-                dist_spec = mapping1[id(param)]
-                compute_spec = mapping2[id(param)]
-                param.process_group = mapping3[id(param)]
-                param.set_tensor_spec(dist_spec, compute_spec)
-    return ret
-
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         register_colo_module(torch.nn.Embedding, ColoEmbedding())
 
     def _pre_context_exec(self):
-        self.state_dict_func = nn.Module.state_dict
-        nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
+        pass
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """

View File

@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
     assert t1.is_replicate()
 
 
+def _run_set_tensor_spec(world_size):
+    if world_size != 4:
+        return
+    pg = ProcessGroup(tp_degree=2, dp_degree=2)
+    spec1 = ColoTensorSpec(pg)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+
+    dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
+    assert t1.is_replicate()
+    t1.set_tensor_spec(*dist_spec2)
+    assert t1.is_shard_1dcol()
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
     _run_operand(world_size)
     _run_wrapped_tensor_func()
     _run_redistributed(world_size)
+    _run_set_tensor_spec(world_size)
 
 
 @pytest.mark.dist
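For intuition on what the new test asserts: ShardSpec([-1], [pg.tp_world_size()]) splits the last tensor dimension across the tensor-parallel group, so a replicated (2, 3, 4) tensor becomes a (2, 3, 2) local shard on each of two TP ranks. A plain-PyTorch sketch of the expected local shapes; this is an analogy, not the ColoTensor implementation:

    # Plain-PyTorch analogy for ShardSpec([-1], [2]): chunk the last dim
    # across two tensor-parallel ranks.
    import torch

    full = torch.randn(2, 3, 4)
    local_shards = torch.chunk(full, chunks=2, dim=-1)
    print([tuple(s.shape) for s in local_shards])   # [(2, 3, 2), (2, 3, 2)]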

View File

@@ -3,7 +3,6 @@ import os, shutil
 import torch
 import torch.nn as nn
 import pytest
-import copy
 from functools import partial
 
 import torch.multiprocessing as mp
@@ -104,7 +103,7 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))
 
 
-def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
+def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     num_epoch = 5
     warmup_epoch = 2
@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
     feature = 32
     category = 16
 
-    train_dataloader = DummyDataLoader(batch, category, feature, length=16)
     with ColoInitContext(device=get_current_device()):
         model = MLP(feature, category)
+
+    with ColoInitContext(device=get_current_device()):
         model_reload = MLP(feature, category)
-        model_ref = MLP(feature, category)
 
     model = model.cuda()
     model_reload = model_reload.cuda()
-    model_ref = model_ref.cuda()
 
     if use_ddp:
         model = ColoDDP(model, pg)
         model_reload = ColoDDP(model_reload, pg)
-        model_ref = ColoDDP(model_ref, pg)
 
     init_spec_func(model, pg)
-    init_spec_func(model_ref, pg)
+    if use_mp_reload:
+        init_spec_func(model_reload, pg)
 
-    criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
     optimizer_reload = torch.optim.Adam(model_reload.parameters(),
                                         lr=0.001,
                                         betas=(0.9, 0.999),
                                         eps=1e-08,
                                         weight_decay=0)
-    optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
 
     lr_scheduler = None
     if test_scheduler == 'colossalai_cosine_warmup':
@@ -154,91 +150,48 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
     else:
         raise TypeError(f"{test_scheduler} is invalid")
 
-    for epoch in range(0, num_epoch):
-        if epoch <= test_epoch:
-            for i, image_dict in enumerate(train_dataloader):
-                if use_ddp:
-                    model.zero_grad()
-                else:
-                    optimizer.zero_grad()
-                logits = model(image_dict['pixel_values'])
-                loss = criterion(logits, image_dict['label'])
-                if use_ddp:
-                    model.backward(loss)
-                else:
-                    loss.backward()
-                optimizer.step()
-
-            if epoch == test_epoch:
-                for ref_p, p in zip(model_ref.parameters(), model.parameters()):
-                    ref_p.data.copy_(p)
-                optimizer_ref = copy.deepcopy(optimizer)
-
-                check_param_equal(model, model_ref)
-                save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
-                dist.barrier()
-        else:
-            if epoch == test_epoch + 1:
-                load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
-                                lr_scheduler_reload)
-                init_spec_func(model_reload, pg)
-
-            for i, image_dict in enumerate(train_dataloader):
-                if use_ddp:
-                    model_ref.zero_grad()
-                    model_reload.zero_grad()
-                else:
-                    optimizer_ref.zero_grad()
-                    optimizer_reload.zero_grad()
-                logits_ref = model_ref(image_dict['pixel_values'])
-                logits_reload = model_reload(image_dict['pixel_values'])
-                loss_ref = criterion(logits_ref, image_dict['label'])
-                loss_reload = criterion(logits_reload, image_dict['label'])
-                if use_ddp:
-                    model_ref.backward(loss_ref)
-                    model_reload.backward(loss_reload)
-                else:
-                    loss_ref.backward()
-                    loss_reload.backward()
-                optimizer_ref.step()
-                optimizer_reload.step()
-
-        lr_scheduler.step()
-
-    check_param_equal(model_ref, model_reload)
+    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
+    dist.barrier()
+    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
+
+    # Since model is sharded, we merge them before param checking.
+    for p in model.parameters():
+        p.to_replicate_()
+
+    for p in model_reload.parameters():
+        p.to_replicate_()
+
+    check_param_equal(model, model_reload)
 
 
-def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
+def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     if use_ddp and world_size == 1:
         return
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec,
-                   use_ddp,
-                   test_epoch=test_epoch,
-                   test_scheduler=test_scheduler,
-                   pg=pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
 
 
+@pytest.mark.skip
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [4])
-@pytest.mark.parametrize('use_ddp', [True])
-@pytest.mark.parametrize('test_epoch', [1, 2, 3])
+@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('use_ddp', [True, False])
+@pytest.mark.parametrize('use_mp_reload', [True, False])
 @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
+def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
     if not os.path.isdir('./checkpoint'):
         os.mkdir('./checkpoint')
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_ddp=use_ddp,
-                       test_epoch=test_epoch,
+                       use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
     remove('./checkpoint')
 
 
 if __name__ == '__main__':
-    test_checkpoint(4, True, 1, "colossalai_cosine_warmup")
+    test_checkpoint(2, True, False, "torch_cosine")