diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 9aca524e9..0e35e694d 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -6,12 +6,13 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 from colossalai.gemini.chunk import TensorState, Chunk from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Set from colossalai.logging import get_dist_logger from collections import OrderedDict from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup from .reducer import Reducer + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -84,6 +85,18 @@ class ColoDDP(torch.nn.Module): def named_parameters(self, prefix: str = '', recurse: bool = True): return self.module.named_parameters(prefix, recurse) + def named_buffers(self, prefix: str = '', recurse: bool = True): + return self.module.named_buffers(prefix, recurse) + + def named_children(self): + return self.module.named_children() + + def named_modules(self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + return self.module.named_modules(memo, prefix, remove_duplicate) + def forward(self, *args, **kwargs): self.module.zero_grad(set_to_none=True) return self.module(*args, **kwargs) @@ -274,7 +287,7 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): r"""Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are @@ -291,18 +304,22 @@ class ZeroDDP(ColoDDP): ['bias', 'weight'] """ + is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 + record_flag = (not only_rank_0) or is_rank_0 + if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars, record_flag) + for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) if hook_result is not None: destination = hook_result return destination - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -315,22 +332,36 @@ class ZeroDDP(ColoDDP): prefix (str): the prefix for parameters and buffers used in this module """ - chunks = self.chunk_manager.get_chunks(self.fp32_params) - chunks_orig_device_type = [] - for chunk in chunks: - chunks_orig_device_type.append(chunk.device_type) + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + # record the original device of the chunk + org_chunk_dev_typ = chunk.device_type self.chunk_manager.access_chunk(chunk) + + for tensor in chunk.get_tensors(): + rec_p = torch.empty([0]) + if record_flag: + rec_p = tensor.cpu() # move the whole tensor to CPU mem + assert tensor not in param_to_save_data + param_to_save_data[tensor] = rec_p + # release the actual memory of the chunk + self.chunk_manager.release_chunk(chunk) + if not chunk.is_empty and org_chunk_dev_typ == 'cpu': + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: - rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu() + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + rec_p = param_to_save_data[fp32_p] destination[prefix + name] = rec_p if keep_vars else rec_p.detach() - for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks): - self.chunk_manager.release_chunk(chunk) - if not chunk.is_empty and orig_dvice_type == 'cpu': - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + + # save all buffers for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach() + # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: @@ -368,7 +399,7 @@ class ZeroDDP(ColoDDP): state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] prefix = '' local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index 359dcafac..c13f7a72c 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,3 +1,5 @@ +import copy + import pytest import colossalai import torch @@ -11,9 +13,9 @@ from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs from colossalai.nn.parallel import ZeroDDP, ColoDDP from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Callable from collections import OrderedDict from colossalai.tensor import ProcessGroup, ColoParameter +from colossalai.testing import parameterize def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): @@ -25,7 +27,27 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic else: temp_t2 = t2 - assert torch.equal(t1, temp_t2) + assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) + + +def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True): + for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()): + assert na == nb + + if not allow_empty: + assert pa.storage().size() > 0 + assert pb.storage().size() > 0 + else: + if pa.storage().size() == 0 or pb.storage().size() == 0: + continue + + if same_dtype: + assert pa.dtype == pb.dtype + temp_pb = pb + else: + temp_pb = pb.to(pa.dtype) + + assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb) def init_ddp(module: torch.nn.Module) -> ColoDDP: @@ -33,22 +55,26 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: return ColoDDP(module, process_group=pg) -def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP: +def init_ddpv2(module: torch.nn.Module, + use_chunk: bool = False, + use_zero: bool = False, + placement_policy: str = 'cuda') -> ZeroDDP: pg = ProcessGroup() chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero) - gemini_manager = GeminiManager('cuda', chunk_manager) + gemini_manager = GeminiManager(placement_policy, chunk_manager) return ZeroDDP(module, gemini_manager) -def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]): - get_components_func = non_distributed_component_funcs.get_callable('nested_model') +def run_ddp_state_dict(): + get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() with ColoInitContext(device=get_current_device()): model = model_builder() - model = ddp_init_func(model) + model = init_ddp(model) torch_state_dict = torch_model.state_dict() + for param in model.parameters(): if isinstance(param, ColoParameter): assert param.get_process_group() is not None @@ -62,13 +88,44 @@ def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]): check_state_dict_equal(torch_state_dict, state_dict) +@parameterize('use_chunk', [False, True]) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('use_zero', [False, True]) +@parameterize('only_rank_0', [False, True]) +def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0): + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + torch_model = model_builder().cuda() + org_torch_model = copy.deepcopy(torch_model) + torch_state_dict = torch_model.state_dict() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + model = init_ddpv2(model, use_chunk, use_zero, placement_policy) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + model.load_state_dict(torch_state_dict, strict=False) + check_model_equal(model, torch_model, allow_empty=True, same_dtype=False) + + for param in model.parameters(): + if isinstance(param, ColoParameter): + assert param.get_process_group() is not None + + pg = ProcessGroup() + state_dict = model.state_dict(only_rank_0=only_rank_0) + if not only_rank_0 or pg.dp_local_rank() == 0: + torch_model.load_state_dict(state_dict, strict=False) + check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_state_dict(init_ddp) - run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False)) - run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True)) - run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False)) - run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True)) + run_ddp_state_dict() + run_zero_state_dict() @pytest.mark.dist