From 9880fd2cd8b3b24c28333926338656a06dd170f3 Mon Sep 17 00:00:00 2001 From: eric8607242 Date: Mon, 9 Jan 2023 14:35:14 +0800 Subject: [PATCH] Fix state_dict key missing issue of the ZeroDDP (#2363) * Fix state_dict output for ZeroDDP duplicated parameters * Rewrite state_dict based on get_static_torch_model * Modify get_static_torch_model to be compatible with the lower version (ZeroDDP) --- colossalai/nn/parallel/data_parallel.py | 37 +++++++++++++++++++++---- colossalai/nn/parallel/utils.py | 16 +++++------ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index e3bb83347..8fd08db95 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -18,6 +18,7 @@ from colossalai.utils import get_current_device from colossalai.zero.utils.gemini_hook import GeminiZeROHook from .reducer import Reducer +from .utils import get_static_torch_model try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP): pin_memory=pin_memory) self.fp32_params.append(fp32_p) self.grads_device[p] = self.gemini_manager.default_device + self.chunk_manager.close_all_groups() self._cast_buffers() @@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - r"""Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True): + r""" + Args: + strict (bool): whether to reture the whole model state + as the original pytorch state_dict() Returns: dict: @@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP): >>> module.state_dict().keys() ['bias', 'weight'] + """ + if strict: + return get_static_torch_model(zero_ddp_model=self, device=get_current_device(), + only_rank_0=only_rank_0).state_dict(destination=destination, + prefix=prefix, + keep_vars=keep_vars) + return self._non_strict_state_dict(destination=destination, + prefix=prefix, + keep_vars=keep_vars, + only_rank_0=only_rank_0) + + def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + r"""Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + Warning: The non strict state dict would ignore the parameters if the + tensors of the parameters are shared with other parameters which + have been included in the dictionary. + + Returns: + dict: + a dictionary containing a whole state of the module """ if destination is None: destination = OrderedDict() diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py index 1205cbc3a..988f97825 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/nn/parallel/utils.py @@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module): return name_to_module[''] -def get_static_torch_model(gemini_ddp_model, +def get_static_torch_model(zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True) -> torch.nn.Module: - """Get a static torch.nn.Module model from the given GeminiDDP module. - You should notice that the original GeminiDDP model is not modified. + """Get a static torch.nn.Module model from the given ZeroDDP module. + You should notice that the original ZeroDDP model is not modified. Thus, you can use the original model in further training. But you should not use the returned torch model to train, this can cause unexpected errors. Args: - gemini_ddp_model (GeminiDDP): a gemini ddp model + zero_ddp_model (ZeroDDP): a zero ddp model device (torch.device): the device of the final torch model dtype (torch.dtype): the dtype of the final torch model only_rank_0 (bool): if True, only rank0 has the coverted torch model @@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model, Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.nn.parallel import GeminiDDP - assert isinstance(gemini_ddp_model, GeminiDDP) + from colossalai.nn.parallel import ZeroDDP + assert isinstance(zero_ddp_model, ZeroDDP) - state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0) - colo_model = gemini_ddp_model.module + state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False) + colo_model = zero_ddp_model.module torch_model = _get_shallow_copy_model(colo_model) if not only_rank_0 or dist.get_rank() == 0: