mirror of https://github.com/hpcaitech/ColossalAI

[polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code
* [testing] remove code
* [gemini] make more robust

Branch: pull/2408/head
Parent: 551cafec14
Commit: ea13a201bb
@@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP):
         self.grads_device[tensor] = device

     def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        r"""
+        """
         Args:
-            strict (bool): whether to reture the whole model state
-                as the original pytorch state_dict()
+            strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`

         Returns:
             dict:
@@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP):
             ['bias', 'weight']
         """
         if strict:
-            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
-                                          only_rank_0=only_rank_0).state_dict(destination=destination,
-                                                                              prefix=prefix,
-                                                                              keep_vars=keep_vars)
+            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
+            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         return self._non_strict_state_dict(destination=destination,
                                            prefix=prefix,
                                            keep_vars=keep_vars,
                                            only_rank_0=only_rank_0)

     def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
+        """Returns a dictionary containing a whole state of the module.

-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
         Parameters and buffers set to ``None`` are not included.

-        Warning: The non strict state dict would ignore the parameters if the
-        tensors of the parameters are shared with other parameters which
-        have been included in the dictionary.
+        Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
+        are shared with other parameters which have been included in the dictionary.
+        When you need to load the state dict, you should set the argument `strict` to False.

         Returns:
             dict:
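For context, a minimal usage sketch of the two paths in the hunk above (not part of the commit; `model` is assumed to be a ZeroDDP-wrapped module and `plain_model` a hypothetical plain copy of the same architecture):

    # strict=True (default): rebuild a full static torch model, then export a
    # complete state dict with plain-PyTorch keys. The new assert rejects keep_vars=True.
    full_sd = model.state_dict(only_rank_0=True, strict=True)

    # strict=False: export directly from the ZeRO-managed tensors. Parameters whose
    # tensors are shared with an already-exported parameter are skipped, so per the
    # docstring's warning the result should be loaded with load_state_dict(strict=False).
    partial_sd = model.state_dict(only_rank_0=True, strict=False)
    plain_model.load_state_dict(partial_sd, strict=False)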
@@ -47,17 +47,16 @@ def _get_shallow_copy_model(model: nn.Module):
     """Get a shallow copy of the given model. Each submodule is different from the original submodule.
     But the new submodule and the old submodule share all attributes.
     """
-    name_to_module = dict()
+    old_to_new = dict()
     for name, module in _get_dfs_module_list(model):
         new_module = copy(module)
         new_module._modules = OrderedDict()
         for subname, submodule in module._modules.items():
             if submodule is None:
                 continue
-            full_name = name + ('.' if name else '') + subname
-            setattr(new_module, subname, name_to_module[full_name])
-        name_to_module[name] = new_module
-    return name_to_module['']
+            setattr(new_module, subname, old_to_new[submodule])
+        old_to_new[module] = new_module
+    return old_to_new[model]


 def get_static_torch_model(zero_ddp_model,
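The rewrite above replaces name-keyed bookkeeping (`name_to_module`, indexed by dotted paths assembled into `full_name`) with identity-keyed bookkeeping (`old_to_new`, indexed by the module objects themselves), which removes the string handling and makes the lookup independent of naming. A self-contained sketch of the same technique; `shallow_copy_model` is an illustrative stand-in that uses `reversed(named_modules())` in place of the repo's `_get_dfs_module_list`, assuming each submodule instance appears once in the tree:

    from collections import OrderedDict
    from copy import copy

    import torch.nn as nn

    def shallow_copy_model(model: nn.Module) -> nn.Module:
        # Visit children before parents, so every submodule's copy exists
        # by the time its parent is rebuilt.
        old_to_new = dict()
        for _, module in reversed(list(model.named_modules())):
            new_module = copy(module)            # shares parameters, buffers, attributes
            new_module._modules = OrderedDict()  # but gets its own child table
            for subname, submodule in module._modules.items():
                if submodule is None:
                    continue
                setattr(new_module, subname, old_to_new[submodule])
            old_to_new[module] = new_module
        return old_to_new[model]

    m = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    m2 = shallow_copy_model(m)
    assert m2 is not m and m2[0] is not m[0]    # every module object is new
    assert m2[0].weight is m[0].weight          # but attributes are shared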
@@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     for key, value in torch_dict.items():
         # key is 'module.model.PARAMETER', so we truncate it
         key = key[7:]
-        if key == 'model.lm_head.weight':
-            continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
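A likely reading of the removed `model.lm_head.weight` skips here and in the following test hunks (an inference from the diff, not stated in the commit): GPT-style test models tie `lm_head.weight` to the input embedding, shared tensors used to be omitted from the ZeRO state dict, and the polished strict path now yields every key, so the special case is no longer needed. A minimal illustration of such weight tying (the model and its names are hypothetical):

    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, vocab=10, dim=4):
            super().__init__()
            self.embed = nn.Embedding(vocab, dim)
            self.lm_head = nn.Linear(dim, vocab, bias=False)
            self.lm_head.weight = self.embed.weight  # tied: one tensor, two names

    sd = TinyLM().state_dict()
    # Both keys are present and point at the same storage.
    assert sd['embed.weight'].data_ptr() == sd['lm_head.weight'].data_ptr()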
@@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     for key, value in torch_dict.items():
         # key is 'module.model.PARAMETER', so we truncate it
         key = key[7:]
-        if key == 'model.lm_head.weight':
-            continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@@ -45,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
     torch_dict = torch_model.state_dict()

     for key, value in torch_dict.items():
-        if key == 'model.lm_head.weight':
-            continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@@ -84,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     zero_dict = model.state_dict(only_rank_0=False)

     for key, value in torch_dict.items():
-        if key == 'model.lm_head.weight':
-            continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
     for key, value in torch_dict.items():
         # key is 'module.model.PARAMETER', so we truncate it
         key = key[7:]
-        if key == 'model.lm_head.weight':
-            continue
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
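With the skips removed, nearly the same checking loop now repeats across these tests, with or without the 'module.' prefix strip. A small helper capturing the shared pattern (an illustrative refactor, not part of the commit):

    import torch

    def assert_state_dicts_match(torch_dict, zero_dict, prefix_len=7):
        # Keys in torch_dict look like 'module.model.PARAMETER'; drop the
        # 7-character 'module.' prefix before looking them up in the ZeRO dict
        # (pass prefix_len=0 where the keys carry no wrapper prefix).
        for key, value in torch_dict.items():
            key = key[prefix_len:]
            assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
            zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
            assert torch.equal(value, zero_value), "parameter '{}' has problem.".format(key)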