diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index cc2e8dee3..25adb212f 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -143,10 +143,10 @@ class ColoTensor(torch.Tensor):
         self._redistribute(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec:
+        if dist_spec is not None:
             assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
             self.set_dist_spec(dist_spec)
-        if compute_spec:
+        if compute_spec is not None:
             self.compute_spec = compute_spec
 
     def has_compute_pattern(self, compute_pattern):
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 4796d420c..1e79f68ac 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 __all__ = ['replicate', 'shard']
 
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 564ccf4b8..3f61aed2f 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -1,19 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.distributed as dist
-import collections
-import inspect
-from colossalai.utils.model.colo_init_context import colo_state_dict
-
-
-def filter_dict(dict_to_filter, thing_with_kwargs):
-    sig = inspect.signature(thing_with_kwargs)
-    filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
-    filter_dict = {}
-    for filter_key in filter_keys:
-        if filter_key in dict_to_filter:
-            filter_dict[filter_key] = dict_to_filter[filter_key]
-    return filter_dict
+from colossalai.tensor import ColoTensor, DistSpecManager
 
 
 def save_checkpoint(dire: str,
@@ -32,21 +19,30 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': model.state_dict()}
+
+    mapping = dict()
+    new_dict = dict()
+
+    # save the dist context about the tensors in a new dict, while still maintaining the original dict.
+    for k, v in model.state_dict().items():
+        if isinstance(v, ColoTensor):
+            mapping[k] = (v.dist_spec, v.compute_spec)
+            new_dict[k] = v.to_replicate().detach()
+    if dist.get_rank() == 0:
+        for k, v in new_dict.items():
+            if isinstance(v, ColoTensor):
+                assert v.is_replicate()
+
+    model_state = {'epoch': epoch, 'model': new_dict}
     torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
-    # TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
-    # 1. convert SHARD ColoTensor to REPLICATE
-    # only rank 0 saves the REPLICATE tensors.
-    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
-
-    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
+    # delete the new dict
+    del new_dict
 
 
 def load_checkpoint(dire,
                     epoch: int,
-                    rank: int,
                     model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
@@ -62,19 +58,18 @@ def load_checkpoint(dire,
         optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
""" + + mapping = dict() + for k, v in model.named_parameters(): + if isinstance(v, ColoTensor): + mapping[k] = (v.dist_spec, v.compute_spec) + v.to_replicate_() + model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch)) - model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()]) model.load_state_dict(model_state['model']) - optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank)) - optimizer.load_state_dict(optim_state['optimizer']) - lr_scheduler_dict = optim_state['lr_scheduler'] - if 'after_scheduler_type' in lr_scheduler_dict: - after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type') - after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict') - reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type) - filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler) - lr_scheduler_dict['after_scheduler'] = reload_scheduler( - optimizer, - **filtered_dict, - ) - lr_scheduler.load_state_dict(lr_scheduler_dict) + + # reset tensors to original dist spec. + with DistSpecManager.no_grad(): + for k, v in model.named_parameters(): + if isinstance(v, ColoTensor): + v.set_tensor_spec(*mapping[k]) diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 1766eca0a..a54740221 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -1,13 +1,10 @@ from .utils import InsertPostInitMethodToModuleSubClasses import torch -from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec - +from colossalai.tensor import ColoTensor, ColoParameter from colossalai.nn.parallel.layers import register_colo_module, \ ColoLinear, ColoEmbedding -from copy import copy from torch import nn from typing import Iterator, Tuple, Union -from functools import partialmethod # find named_params includes replica @@ -34,47 +31,6 @@ def ColoModulize(module): module._colo_visited = True -def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None): - # build param to spec mapping - mapping1 = dict() - mapping2 = dict() - mapping3 = dict() - # gather all params - has_dist_parameter = False - with torch.no_grad(): - for param in self.parameters(): - if isinstance(param, ColoParameter): - has_dist_parameter = True - mapping1[id(param)] = copy(param.dist_spec) - mapping2[id(param)] = copy(param.compute_spec) - # TODO(jiaruifang) fixme, we should elegently handle the default PG in init context - if param.get_process_group() is None: - param.process_group = ProcessGroup() - param.set_dist_spec(distspec.replicate()) - mapping3[id(param)] = param.get_process_group() - param.process_group = None - - # TODO: fix when keep_vars = True - # when keep_vars = False, the state_dict_func will call detach to create - # new tensors, but when keep_vars = True, the recovery of spec will be reflected - # in the `ret`, such that the final state dict will still contain process group, - # raising exception as it is not serializable - assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.' 
-
-    ret = state_dict_func(self, destination, prefix, keep_vars)
-
-    # recover
-    with torch.no_grad():
-        for param in self.parameters():
-            param_id = id(param)
-            if param_id in mapping1:
-                dist_spec = mapping1[id(param)]
-                compute_spec = mapping2[id(param)]
-                param.process_group = mapping3[id(param)]
-                param.set_tensor_spec(dist_spec, compute_spec)
-    return ret
-
-
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -94,8 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         register_colo_module(torch.nn.Embedding, ColoEmbedding())
 
     def _pre_context_exec(self):
-        self.state_dict_func = nn.Module.state_dict
-        nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
+        pass
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index addbf304d..ad9547ef6 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -122,6 +122,19 @@ def _run_redistributed(world_size):
     assert t1.is_replicate()
 
 
+def _run_set_tensor_spec(world_size):
+    if world_size != 4:
+        return
+    pg = ProcessGroup(tp_degree=2, dp_degree=2)
+    spec1 = ColoTensorSpec(pg)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+
+    dist_spec2 = (ShardSpec([-1], [pg.tp_world_size()]), None)
+    assert t1.is_replicate()
+    t1.set_dist_spec(*dist_spec2)
+    assert t1.is_shard_1dcol()
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
@@ -132,6 +145,7 @@ def run_dist_tests(rank, world_size, port):
     _run_operand(world_size)
     _run_wrapped_tensor_func()
     _run_redistributed(world_size)
+    _run_set_tensor_spec(world_size)
 
 
 @pytest.mark.dist
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 969c8b352..0581d7bf0 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -3,7 +3,6 @@ import os, shutil
 import torch
 import torch.nn as nn
 import pytest
-import copy
 from functools import partial
 
 import torch.multiprocessing as mp
@@ -104,7 +103,7 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))
 
 
-def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
+def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     num_epoch = 5
     warmup_epoch = 2
 
@@ -112,31 +111,28 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
     feature = 32
     category = 16
 
-    train_dataloader = DummyDataLoader(batch, category, feature, length=16)
     with ColoInitContext(device=get_current_device()):
         model = MLP(feature, category)
+
+    with ColoInitContext(device=get_current_device()):
         model_reload = MLP(feature, category)
-        model_ref = MLP(feature, category)
 
     model = model.cuda()
     model_reload = model_reload.cuda()
-    model_ref = model_ref.cuda()
 
     if use_ddp:
         model = ColoDDP(model, pg)
         model_reload = ColoDDP(model_reload, pg)
-        model_ref = ColoDDP(model_ref, pg)
 
     init_spec_func(model, pg)
-    init_spec_func(model_ref, pg)
+    if use_mp_reload:
+        init_spec_func(model_reload, pg)
 
-    criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
     optimizer_reload = torch.optim.Adam(model_reload.parameters(),
                                         lr=0.001,
                                         betas=(0.9, 0.999),
                                         eps=1e-08,
                                         weight_decay=0)
-    optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
 
     lr_scheduler = None
     if test_scheduler == 'colossalai_cosine_warmup':
@@ -154,91 +150,48 @@
     else:
         raise TypeError(f"{test_scheduler} is invalid")
 
-    for epoch in range(0, num_epoch):
-        if epoch <= test_epoch:
-            for i, image_dict in enumerate(train_dataloader):
-                if use_ddp:
-                    model.zero_grad()
-                else:
-                    optimizer.zero_grad()
-                logits = model(image_dict['pixel_values'])
-                loss = criterion(logits, image_dict['label'])
-                if use_ddp:
-                    model.backward(loss)
-                else:
-                    loss.backward()
-                optimizer.step()
+    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
+    dist.barrier()
+    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
 
-            if epoch == test_epoch:
-                for ref_p, p in zip(model_ref.parameters(), model.parameters()):
-                    ref_p.data.copy_(p)
-                optimizer_ref = copy.deepcopy(optimizer)
+    # Since the model is sharded, we merge its parameters before the param check.
+    for p in model.parameters():
+        p.to_replicate_()
 
-                check_param_equal(model, model_ref)
-                save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
-                dist.barrier()
-        else:
-            if epoch == test_epoch + 1:
-                load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
-                                lr_scheduler_reload)
-                init_spec_func(model_reload, pg)
-            for i, image_dict in enumerate(train_dataloader):
-                if use_ddp:
-                    model_ref.zero_grad()
-                    model_reload.zero_grad()
-                else:
-                    optimizer_ref.zero_grad()
-                    optimizer_reload.zero_grad()
-                logits_ref = model_ref(image_dict['pixel_values'])
-                logits_reload = model_reload(image_dict['pixel_values'])
-                loss_ref = criterion(logits_ref, image_dict['label'])
-                loss_reload = criterion(logits_reload, image_dict['label'])
-                if use_ddp:
-                    model_ref.backward(loss_ref)
-                    model_reload.backward(loss_reload)
-                else:
-                    loss_ref.backward()
-                    loss_reload.backward()
-                optimizer_ref.step()
-                optimizer_reload.step()
-                lr_scheduler.step()
+    for p in model_reload.parameters():
+        p.to_replicate_()
 
-            check_param_equal(model_ref, model_reload)
+    check_param_equal(model, model_reload)
 
 
-def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
+def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     if use_ddp and world_size == 1:
         return
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec,
-                   use_ddp,
-                   test_epoch=test_epoch,
-                   test_scheduler=test_scheduler,
-                   pg=pg)
+    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
 
 
-@pytest.mark.skip
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [4])
-@pytest.mark.parametrize('use_ddp', [True])
-@pytest.mark.parametrize('test_epoch', [1, 2, 3])
+@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('use_ddp', [True, False])
+@pytest.mark.parametrize('use_mp_reload', [True, False])
 @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
+def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
     if not os.path.isdir('./checkpoint'):
         os.mkdir('./checkpoint')
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_ddp=use_ddp,
-                       test_epoch=test_epoch,
+                       use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
     remove('./checkpoint')
 
 
 if __name__ == '__main__':
-    test_checkpoint(4, True, 1, "colossalai_cosine_warmup")
+    test_checkpoint(2, True, False, "torch_cosine")
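
Usage note (not part of the patch): the sketch below shows how the reworked save_checkpoint/load_checkpoint pair is exercised, mirroring the calls in the updated test. The import paths are assumed from the files touched above, torch.nn.Linear stands in for the test's MLP, and it presumes colossalai.launch() has already initialized the process group on every rank and that the './checkpoint' directory exists.

# Minimal sketch, not part of the patch. save_checkpoint() replicates ColoTensor state
# before writing, and load_checkpoint() replicates parameters, loads the state dict,
# then restores each parameter's original dist spec.
import torch
from colossalai.utils import get_current_device
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(32, 16)    # stand-in for the test's MLP(feature, category)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

save_checkpoint('./checkpoint', 0, model, optimizer)    # writes ./checkpoint/epoch_0_model.pth
load_checkpoint('./checkpoint', 0, model, optimizer)    # reads it back and restores the dist specs

Note that after this change only the model state is persisted; the per-rank optimizer and lr-scheduler files written by the old implementation are no longer saved or loaded.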