mirror of https://github.com/hpcaitech/ColossalAI
Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
parent ce08661eb1
commit 9880fd2cd8
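For context on the duplicated parameters mentioned in the commit message: the problematic case is several module attributes sharing one tensor (weight tying). A minimal plain-PyTorch sketch of that situation, independent of ColossalAI, showing that a vanilla state_dict keeps one key per name even when the underlying storage is shared, which is the behavior the strict path in this commit reproduces:

import torch.nn as nn

class Tied(nn.Module):
    # Two linear layers tied to a single weight tensor.
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4, bias=False)
        self.b = nn.Linear(4, 4, bias=False)
        self.b.weight = self.a.weight  # 'a.weight' and 'b.weight' alias one Parameter

m = Tied()
# Vanilla PyTorch emits both keys even though the tensors are shared.
print(sorted(m.state_dict().keys()))  # ['a.weight', 'b.weight']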
@@ -18,6 +18,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
 from .reducer import Reducer
+from .utils import get_static_torch_model
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP):
                                                pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
 
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
@@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
-
-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
-        Parameters and buffers set to ``None`` are not included.
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
+        r"""
+        Args:
+            strict (bool): whether to return the whole model state
+                as the original pytorch state_dict()
 
         Returns:
             dict:
@@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP):
 
             >>> module.state_dict().keys()
             ['bias', 'weight']
         """
+        if strict:
+            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
+                                          only_rank_0=only_rank_0).state_dict(destination=destination,
+                                                                              prefix=prefix,
+                                                                              keep_vars=keep_vars)
+        return self._non_strict_state_dict(destination=destination,
+                                           prefix=prefix,
+                                           keep_vars=keep_vars,
+                                           only_rank_0=only_rank_0)
+
+    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+        r"""Returns a dictionary containing a whole state of the module.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Warning: The non strict state dict would ignore the parameters if the
+            tensors of the parameters are shared with other parameters which
+            have been included in the dictionary.
+
+        Returns:
+            dict:
+                a dictionary containing a whole state of the module
+        """
         if destination is None:
             destination = OrderedDict()
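A minimal usage sketch of the new flag, assuming zero_ddp_model is an already-built ZeroDDP instance (the colossalai launch and Gemini/chunk configuration are omitted): strict=True gathers a full PyTorch-style state dict through get_static_torch_model, while strict=False takes the non-strict path, which, per its docstring, may skip parameters whose tensors are shared with ones already emitted.

def compare_state_dicts(zero_ddp_model):
    # Full, PyTorch-compatible key set (strict=True is the default).
    full_sd = zero_ddp_model.state_dict(only_rank_0=True, strict=True)
    # Non-strict path: shared parameters may be dropped after their first appearance.
    loose_sd = zero_ddp_model.state_dict(only_rank_0=True, strict=False)
    missing = set(full_sd) - set(loose_sd)
    print("keys only present in the strict dict:", sorted(missing))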
@@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module):
     return name_to_module['']
 
 
-def get_static_torch_model(gemini_ddp_model,
+def get_static_torch_model(zero_ddp_model,
                            device=torch.device("cpu"),
                            dtype=torch.float32,
                            only_rank_0=True) -> torch.nn.Module:
-    """Get a static torch.nn.Module model from the given GeminiDDP module.
-    You should notice that the original GeminiDDP model is not modified.
+    """Get a static torch.nn.Module model from the given ZeroDDP module.
+    You should notice that the original ZeroDDP model is not modified.
     Thus, you can use the original model in further training.
     But you should not use the returned torch model to train, this can cause unexpected errors.
 
     Args:
-        gemini_ddp_model (GeminiDDP): a gemini ddp model
+        zero_ddp_model (ZeroDDP): a zero ddp model
         device (torch.device): the device of the final torch model
         dtype (torch.dtype): the dtype of the final torch model
         only_rank_0 (bool): if True, only rank0 has the converted torch model
@@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model,
     Returns:
         torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
-    from colossalai.nn.parallel import GeminiDDP
-    assert isinstance(gemini_ddp_model, GeminiDDP)
+    from colossalai.nn.parallel import ZeroDDP
+    assert isinstance(zero_ddp_model, ZeroDDP)
 
-    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
-    colo_model = gemini_ddp_model.module
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)
 
     if not only_rank_0 or dist.get_rank() == 0:
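Putting the pieces together, a hedged sketch of saving a plain checkpoint through get_static_torch_model; zero_ddp_model is assumed to be an already-initialized ZeroDDP model, and the import path is inferred from the relative .utils import added above, so it may differ between ColossalAI versions.

import torch
import torch.distributed as dist
from colossalai.nn.parallel.utils import get_static_torch_model  # assumed path

def save_static_checkpoint(zero_ddp_model, path="model_ckpt.pt", only_rank_0=True):
    # Build a detached torch.nn.Module copy on CPU in fp32; the original
    # ZeroDDP model is left untouched and can continue training.
    torch_model = get_static_torch_model(zero_ddp_model,
                                         device=torch.device("cpu"),
                                         dtype=torch.float32,
                                         only_rank_0=only_rank_0)
    # With only_rank_0=True, only rank 0 holds the gathered weights.
    if not only_rank_0 or dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), path)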