mirror of https://github.com/hpcaitech/ColossalAI
[polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code
* [testing] remove code
* [gemini] make more robust

parent 551cafec14
commit ea13a201bb
@@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP):
         self.grads_device[tensor] = device

     def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        r"""
+        """
         Args:
-            strict (bool): whether to reture the whole model state
-                as the original pytorch state_dict()
+            strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`

         Returns:
             dict:
@@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP):
             ['bias', 'weight']
         """
         if strict:
-            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
-                                          only_rank_0=only_rank_0).state_dict(destination=destination,
-                                                                              prefix=prefix,
-                                                                              keep_vars=keep_vars)
+            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
+            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         return self._non_strict_state_dict(destination=destination,
                                            prefix=prefix,
                                            keep_vars=keep_vars,
                                            only_rank_0=only_rank_0)

     def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
+        """Returns a dictionary containing a whole state of the module.

-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
         Parameters and buffers set to ``None`` are not included.

-        Warning: The non strict state dict would ignore the parameters if the
-            tensors of the parameters are shared with other parameters which
-            have been included in the dictionary.
+        Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
+            are shared with other parameters which have been included in the dictionary.
+            When you need to load the state dict, you should set the argument `strict` to False.

         Returns:
             dict:
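Taken together, the two hunks above route `strict=True` through `get_static_torch_model` before serializing. A minimal usage sketch of the two paths, assuming `model` is an already constructed `ZeroDDP` instance (its construction is not part of this diff):

    # strict=True (the default) first materializes a static torch model, so the
    # returned dict looks like a plain `Module.state_dict()`; the new assert
    # rejects `keep_vars=True` on this path.
    full_state = model.state_dict(only_rank_0=True, strict=True)

    # strict=False falls back to `_non_strict_state_dict`, which may skip
    # parameters whose tensors are shared with ones already recorded.
    sharded_state = model.state_dict(only_rank_0=False, strict=False)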
@@ -47,17 +47,16 @@ def _get_shallow_copy_model(model: nn.Module):
     """Get a shallow copy of the given model. Each submodule is different from the original submodule.
     But the new submodule and the old submodule share all attributes.
     """
-    name_to_module = dict()
+    old_to_new = dict()
     for name, module in _get_dfs_module_list(model):
         new_module = copy(module)
         new_module._modules = OrderedDict()
         for subname, submodule in module._modules.items():
             if submodule is None:
                 continue
-            full_name = name + ('.' if name else '') + subname
-            setattr(new_module, subname, name_to_module[full_name])
-        name_to_module[name] = new_module
-    return name_to_module['']
+            setattr(new_module, subname, old_to_new[submodule])
+        old_to_new[module] = new_module
+    return old_to_new[model]


 def get_static_torch_model(zero_ddp_model,
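The rewritten helper keys the copies by module object rather than by dotted name, which is what makes it more robust to shared submodules. Below is a self-contained sketch of the same idea; it substitutes `reversed(list(model.named_modules()))` for the repo's `_get_dfs_module_list` (not shown in this diff), so treat it as an illustration rather than the library code:

    from collections import OrderedDict
    from copy import copy

    import torch.nn as nn


    def shallow_copy_model(model: nn.Module) -> nn.Module:
        old_to_new = dict()
        # reversed pre-order: every child is copied before its parent is visited
        for _, module in reversed(list(model.named_modules())):
            new_module = copy(module)            # shallow copy: attributes are shared
            new_module._modules = OrderedDict()  # but the child table is rebuilt
            for subname, submodule in module._modules.items():
                if submodule is None:
                    continue
                setattr(new_module, subname, old_to_new[submodule])
            old_to_new[module] = new_module
        return old_to_new[model]


    seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    clone = shallow_copy_model(seq)
    assert clone is not seq and clone[0] is not seq[0]
    assert clone[0].weight is seq[0].weight    # parameters themselves are still shared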
@@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     for key, value in torch_dict.items():
-        # key is 'module.model.PARAMETER', so we truncate it
-        key = key[7:]
         if key == 'model.lm_head.weight':
             continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
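The remaining hunks repeat the same key-by-key check in other test files; the only change is dropping the truncation of the leading 'module.' prefix, which the compared dictionaries evidently no longer carry. Condensed into a standalone helper for reference (the name `assert_state_dicts_match` is illustrative, not from the repo):

    import torch


    def assert_state_dicts_match(torch_dict, zero_dict, skip_keys=('model.lm_head.weight',)):
        for key, value in torch_dict.items():
            if key in skip_keys:
                continue
            assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
            temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
            assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)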
@@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     for key, value in torch_dict.items():
-        # key is 'module.model.PARAMETER', so we truncate it
-        key = key[7:]
         if key == 'model.lm_head.weight':
             continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@@ -45,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
     torch_dict = torch_model.state_dict()

     for key, value in torch_dict.items():
         if key == 'model.lm_head.weight':
             continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@@ -84,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     zero_dict = model.state_dict(only_rank_0=False)

     for key, value in torch_dict.items():
         if key == 'model.lm_head.weight':
             continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
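The warning added to `_non_strict_state_dict` above implies a matching non-strict load on the consumer side. A hedged sketch of that round trip (assuming `model` is a `ZeroDDP` instance and `torch_model` a matching plain module, both built outside this diff):

    # A non-strict dict may omit parameters whose tensors are shared with ones
    # already saved, so the consumer must also load non-strictly, as the new
    # docstring warning in `_non_strict_state_dict` points out.
    zero_dict = model.state_dict(only_rank_0=False, strict=False)
    torch_model.load_state_dict(zero_dict, strict=False)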
@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
     for key, value in torch_dict.items():
-        # key is 'module.model.PARAMETER', so we truncate it
-        key = key[7:]
         if key == 'model.lm_head.weight':
             continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))