@@ -360,24 +360,20 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result

         return destination

-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
-        r"""Saves module state to `destination` dictionary, containing a state
-        of the module, but not its descendants. This is called on every
-        submodule in :meth:`~torch.nn.Module.state_dict`.
-
-        In rare cases, subclasses can achieve class-specific behavior by
-        overriding this method with custom logic.
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+        """
+        Get param content from chunks.

         Args:
-            destination (dict): a dict where state will be stored
-            prefix (str): the prefix for parameters and buffers used in this
-                module
-        """
-        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
-
+            param_list (List[torch.nn.Parameter]): a list of parameters whose payloads should be collected
+            only_rank_0 (bool): whether the payloads are collected only on rank 0
+
+        Returns:
+            Dict: a dict mapping each parameter to a tensor holding its full payload
+        """
         # save parameters
         param_to_save_data = dict()
-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk)
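The hunk above turns the chunk-walking part of `_save_to_state_dict` into the reusable helper `_get_param_to_save_data`, which gathers each chunk into a temporary CUDA buffer and copies every parameter's payload out of it. Below is a minimal sketch of that slicing pattern only; `TensorInfo` and `gather_params_from_chunk` are made-up stand-ins for the real chunk metadata, not ColossalAI APIs.

```python
# Sketch only: recover per-parameter payloads from a flat "chunk" buffer.
# TensorInfo and gather_params_from_chunk are illustrative stand-ins.
from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class TensorInfo:
    offset: int  # first element of the tensor inside the flat chunk
    end: int     # one past its last element


def gather_params_from_chunk(flat_chunk: torch.Tensor,
                             tensors_info: Dict[torch.Tensor, TensorInfo]) -> Dict[torch.Tensor, torch.Tensor]:
    """Copy each parameter's slice out of the flat chunk, restoring its shape."""
    param_to_data = {}
    for tensor, info in tensors_info.items():
        # a contiguous 1-D slice can be viewed back into the parameter's shape
        param_to_data[tensor] = flat_chunk[info.offset:info.end].view(tensor.shape).cpu()
    return param_to_data


if __name__ == "__main__":
    a = torch.nn.Parameter(torch.randn(2, 3))
    b = torch.nn.Parameter(torch.randn(4))
    flat = torch.cat([a.detach().flatten(), b.detach().flatten()])
    info = {a: TensorInfo(0, 6), b: TensorInfo(6, 10)}
    payloads = gather_params_from_chunk(flat, info)
    assert torch.equal(payloads[a], a.detach())
```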
@@ -391,7 +387,37 @@ class ZeroDDP(ColoDDP):
                 param_to_save_data[tensor] = record_tensor

             del temp_chunk
+        return param_to_save_data

+    def torch_named_parameters(self):
+        """
+        Get named_parameters() of self.module, but yield the real param.data payload for each parameter.
+        It works the same way as torch.nn.Module.named_parameters().
+        """
+        params_list = [p for p in self.parameters(recurse=True)]
+        param_to_save_data = self._get_param_to_save_data(params_list, False)
+        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
+            if p is not None:
+                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[p]
+                yield name, record_parameter
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
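The second hunk adds `torch_named_parameters()`, a generator that walks `named_parameters()` but yields the payload gathered from the chunks (it calls `_get_param_to_save_data(params_list, False)`, i.e. with `only_rank_0=False`, so presumably every rank receives the data), and rewrites `_save_to_state_dict` on top of the same helper. A hedged usage sketch follows, assuming `zero_model` is a `ZeroDDP` instance built elsewhere; `collect_full_params` is a made-up helper name, not part of the diff.

```python
# Hypothetical usage (not part of the diff): consume torch_named_parameters()
# to assemble a plain {name: tensor} mapping from a ZeroDDP-wrapped model.
from typing import Dict

import torch


def collect_full_params(zero_model) -> Dict[str, torch.Tensor]:
    """Gather every parameter's real payload, keyed by its dotted name."""
    full_params = {}
    for name, payload in zero_model.torch_named_parameters():
        # payload carries the param.data content reconstructed from the chunks
        full_params[name] = payload.clone()
    return full_params


# e.g. torch.save(collect_full_params(zero_model), "zero_model_params.pt")
```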