From ea13a201bbd7eb6022069c8379f3626f9788b0f9 Mon Sep 17 00:00:00 2001 From: HELSON Date: Mon, 9 Jan 2023 17:41:38 +0800 Subject: [PATCH] [polish] polish code for get_static_torch_model (#2405) * [gemini] polish code * [testing] remove code * [gemini] make more robust --- colossalai/nn/parallel/data_parallel.py | 24 +++++++++---------- colossalai/nn/parallel/utils.py | 9 ++++--- tests/test_gemini/update/test_grad_clip.py | 2 -- tests/test_gemini/update/test_optim.py | 2 -- .../update/test_zeroddp_state_dict.py | 4 ---- tests/test_tensor/test_tp_with_zero.py | 2 -- 6 files changed, 15 insertions(+), 28 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 8fd08db95..a7d79be16 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -334,10 +334,9 @@ class ZeroDDP(ColoDDP): self.grads_device[tensor] = device def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True): - r""" + """ Args: - strict (bool): whether to reture the whole model state - as the original pytorch state_dict() + strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()` Returns: dict: @@ -349,25 +348,24 @@ class ZeroDDP(ColoDDP): ['bias', 'weight'] """ if strict: - return get_static_torch_model(zero_ddp_model=self, device=get_current_device(), - only_rank_0=only_rank_0).state_dict(destination=destination, - prefix=prefix, - keep_vars=keep_vars) + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0) + return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) return self._non_strict_state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars, only_rank_0=only_rank_0) def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - r"""Returns a dictionary containing a whole state of the module. + """Returns a dictionary containing a whole state of the module. - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. - Warning: The non strict state dict would ignore the parameters if the - tensors of the parameters are shared with other parameters which - have been included in the dictionary. + Warning: The non strict state dict would ignore the parameters if the tensors of the parameters + are shared with other parameters which have been included in the dictionary. + When you need to load the state dict, you should set the argument `strict` to False. Returns: dict: diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py index 988f97825..d323556d5 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/nn/parallel/utils.py @@ -47,17 +47,16 @@ def _get_shallow_copy_model(model: nn.Module): """Get a shallow copy of the given model. Each submodule is different from the original submodule. But the new submodule and the old submodule share all attributes. 
""" - name_to_module = dict() + old_to_new = dict() for name, module in _get_dfs_module_list(model): new_module = copy(module) new_module._modules = OrderedDict() for subname, submodule in module._modules.items(): if submodule is None: continue - full_name = name + ('.' if name else '') + subname - setattr(new_module, subname, name_to_module[full_name]) - name_to_module[name] = new_module - return name_to_module[''] + setattr(new_module, subname, old_to_new[submodule]) + old_to_new[module] = new_module + return old_to_new[model] def get_static_torch_model(zero_ddp_model, diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py index 185521edb..fda1cf8cf 100644 --- a/tests/test_gemini/update/test_grad_clip.py +++ b/tests/test_gemini/update/test_grad_clip.py @@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 34509cc0c..07e6e65f2 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index 7b0c6e37a..b902bb0f0 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -45,8 +45,6 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) @@ -84,8 +82,6 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): zero_dict = model.state_dict(only_rank_0=False) for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 33db676cb..7e611e8a1 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] - if key == 
'model.lm_head.weight': - continue assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
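
Note on the `_get_shallow_copy_model` change above: the copy map is now keyed by the original module object (`old_to_new[submodule]`) instead of by its dotted name, so a submodule that is attached in more than one place (for example tied `lm_head`/embedding weights) maps to a single copy, which is also why the `model.lm_head.weight` skips could be dropped from the tests. Below is a minimal, self-contained sketch of that idea, not the library code: `shallow_copy_model` is a hypothetical name, and `reversed(list(model.named_modules()))` stands in for the repo's `_get_dfs_module_list` traversal, which (judging from the snippet) yields children before their parents.

# A minimal, self-contained sketch (not the library code) of the shallow-copy idea:
# the copy map is keyed by module identity, so a module reused in several places
# is copied exactly once and every parent is re-wired to that single copy.
from collections import OrderedDict
from copy import copy

import torch.nn as nn


def shallow_copy_model(model: nn.Module) -> nn.Module:    # hypothetical name
    old_to_new = dict()
    # reversed pre-order traversal: children are visited before their parents,
    # so each child's copy already exists when its parent is rebuilt
    for _, module in reversed(list(model.named_modules())):
        new_module = copy(module)              # shallow copy: shares parameters, buffers, attributes
        new_module._modules = OrderedDict()    # drop child links, then re-wire them to the copies
        for child_name, child in module._modules.items():
            if child is None:
                continue
            setattr(new_module, child_name, old_to_new[child])
        old_to_new[module] = new_module
    return old_to_new[model]


if __name__ == "__main__":
    shared = nn.Linear(4, 4)
    model = nn.Sequential(shared, nn.ReLU(), shared)    # the same submodule appears twice
    copied = shallow_copy_model(model)
    assert copied[0] is not model[0]                    # new module objects ...
    assert copied[0].weight is model[0].weight          # ... but the parameters are shared
    assert copied[0] is copied[2]                       # the reused submodule maps to one copy

For reference, after this patch `ZeroDDP.state_dict(strict=True)` materializes such a static copy through `get_static_torch_model` and returns that copy's ordinary `Module.state_dict()`, which is why `keep_vars=True` is rejected on that path.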