mirror of https://github.com/hpcaitech/ColossalAI
Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)

pull/2405/head
parent ce08661eb1
commit 9880fd2cd8
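
The issue being fixed: when a module reuses one parameter tensor under several names (tied weights), ZeroDDP's previous state_dict could emit the tensor only once and leave the other keys missing. A minimal plain-PyTorch illustration of that kind of duplication, using a hypothetical TiedHead module that is not part of this commit:

import torch.nn as nn

class TiedHead(nn.Module):
    """Hypothetical module whose head weight is tied to its embedding weight."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight    # one tensor, two parameter names

m = TiedHead()
# A vanilla torch state_dict keeps both names even though they alias one tensor.
print(sorted(m.state_dict().keys()))    # ['embed.weight', 'head.weight']
# A state dict that skips tensors it has already emitted would drop 'head.weight';
# that is the key-missing behaviour this commit fixes for ZeroDDP when strict=True.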
@@ -18,6 +18,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
 from .reducer import Reducer
+from .utils import get_static_torch_model
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP):
                                                pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
 
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
@@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
-
-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
-        Parameters and buffers set to ``None`` are not included.
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
+        r"""
+        Args:
+            strict (bool): whether to return the whole model state
+                as the original pytorch state_dict()
 
         Returns:
             dict:
@@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP):
 
             >>> module.state_dict().keys()
             ['bias', 'weight']
+        """
+        if strict:
+            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
+                                          only_rank_0=only_rank_0).state_dict(destination=destination,
+                                                                              prefix=prefix,
+                                                                              keep_vars=keep_vars)
+        return self._non_strict_state_dict(destination=destination,
+                                           prefix=prefix,
+                                           keep_vars=keep_vars,
+                                           only_rank_0=only_rank_0)
 
+    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+        r"""Returns a dictionary containing a whole state of the module.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Warning: The non strict state dict would ignore the parameters if the
+            tensors of the parameters are shared with other parameters which
+            have been included in the dictionary.
+
+        Returns:
+            dict:
+                a dictionary containing a whole state of the module
         """
         if destination is None:
             destination = OrderedDict()
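
With the dispatch above, the default strict=True path rebuilds a full torch-style state dict through get_static_torch_model, while strict=False keeps the previous deduplicated behaviour. A hedged usage sketch, where model is assumed to be a ZeroDDP-wrapped module whose construction is elided:

import torch

# default: every parameter and buffer name appears, including keys that alias
# shared tensors, so the result can be loaded into the plain torch module
full_sd = model.state_dict(only_rank_0=True, strict=True)

# legacy behaviour: tensors already emitted under another name are skipped,
# which is what get_static_torch_model consumes internally
dedup_sd = model.state_dict(only_rank_0=True, strict=False)

# with only_rank_0=True the gathered dict is complete only on rank 0
if torch.distributed.get_rank() == 0:
    torch.save(full_sd, "checkpoint.pt")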
@@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module):
     return name_to_module['']
 
 
-def get_static_torch_model(gemini_ddp_model,
+def get_static_torch_model(zero_ddp_model,
                            device=torch.device("cpu"),
                            dtype=torch.float32,
                            only_rank_0=True) -> torch.nn.Module:
-    """Get a static torch.nn.Module model from the given GeminiDDP module.
-    You should notice that the original GeminiDDP model is not modified.
+    """Get a static torch.nn.Module model from the given ZeroDDP module.
+    You should notice that the original ZeroDDP model is not modified.
     Thus, you can use the original model in further training.
     But you should not use the returned torch model to train, this can cause unexpected errors.
 
     Args:
-        gemini_ddp_model (GeminiDDP): a gemini ddp model
+        zero_ddp_model (ZeroDDP): a zero ddp model
         device (torch.device): the device of the final torch model
         dtype (torch.dtype): the dtype of the final torch model
         only_rank_0 (bool): if True, only rank0 has the converted torch model
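
The rename above also loosens the helper's type check from GeminiDDP to ZeroDDP, matching the commit note about compatibility with the lower-level wrapper. A small sketch of the resulting acceptance, assuming (as this package layout suggests) that GeminiDDP derives from ZeroDDP:

from colossalai.nn.parallel import GeminiDDP, ZeroDDP

# assumption: GeminiDDP subclasses ZeroDDP, so instances of either wrapper
# now satisfy the isinstance(zero_ddp_model, ZeroDDP) assert in the next hunk
assert issubclass(GeminiDDP, ZeroDDP)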
@@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model,
     Returns:
         torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
-    from colossalai.nn.parallel import GeminiDDP
-    assert isinstance(gemini_ddp_model, GeminiDDP)
+    from colossalai.nn.parallel import ZeroDDP
+    assert isinstance(zero_ddp_model, ZeroDDP)
 
-    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
-    colo_model = gemini_ddp_model.module
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)
 
     if not only_rank_0 or dist.get_rank() == 0:
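
Taken together, get_static_torch_model can now gather a plain torch.nn.Module from any ZeroDDP model, which is what the strict state_dict path relies on. A hedged sketch of calling it directly for checkpointing or numeric checks; zero_model is assumed to be an initialized ZeroDDP instance, and the helper is assumed importable from colossalai.nn.parallel.utils, per the relative import added in the first hunk:

import torch
import torch.distributed as dist

from colossalai.nn.parallel.utils import get_static_torch_model

# gather a static fp32 copy of the model; with only_rank_0=True the full
# copy exists on rank 0 only
torch_model = get_static_torch_model(zero_ddp_model=zero_model,
                                     device=torch.device("cpu"),
                                     dtype=torch.float32,
                                     only_rank_0=True)

if dist.get_rank() == 0:
    # the copy is meant for saving or numeric comparison, not further training
    torch.save(torch_model.state_dict(), "zero_ddp_checkpoint.pt")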