@@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP):
         self.grads_device[tensor] = device
 
     def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        r"""
+        """
         Args:
-            strict (bool): whether to reture the whole model state
-                as the original pytorch state_dict()
+            strict (bool): whether to return the whole model state as the pytorch `Module.state_dict()`
 
         Returns:
             dict:
@@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP):
             ['bias', 'weight']
         """
         if strict:
-            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
-                                          only_rank_0=only_rank_0).state_dict(destination=destination,
-                                                                              prefix=prefix,
-                                                                              keep_vars=keep_vars)
+            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
+            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         return self._non_strict_state_dict(destination=destination,
                                            prefix=prefix,
                                            keep_vars=keep_vars,
                                            only_rank_0=only_rank_0)
 
     def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
+        """Returns a dictionary containing a whole state of the module.
 
-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
         Parameters and buffers set to ``None`` are not included.
 
-        Warning: The non strict state dict would ignore the parameters if the
-            tensors of the parameters are shared with other parameters which
-            have been included in the dictionary.
+        Warning: The non-strict state dict ignores a parameter whose tensor is shared with another
+            parameter that has already been included in the dictionary.
+            When you need to load the state dict, you should set the argument `strict` to False.
 
         Returns:
             dict:
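
For orientation, here is a minimal usage sketch of the two code paths this diff distinguishes. It assumes `model` is a `ZeroDDP`-wrapped module and `torch_model` is a plain `torch.nn.Module` with the same architecture; both names are illustrative, not part of the diff.

# strict=True (the default): gather the full, unsharded model state, keyed
# exactly like a vanilla torch.nn.Module.state_dict(), so it can be loaded
# into an ordinary PyTorch module. Note that keep_vars=True is rejected here.
full_sd = model.state_dict(only_rank_0=True)
torch_model.load_state_dict(full_sd)

# strict=False: parameters whose tensors are shared with a parameter already
# in the dict are skipped, so the result may be missing keys. Per the new
# docstring, reload it with strict=False so PyTorch tolerates missing entries.
partial_sd = model.state_dict(strict=False)
torch_model.load_state_dict(partial_sd, strict=False)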