From 2e9487096709ba619bec0c2a4961885a24b8c1f3 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Sun, 8 Oct 2023 20:22:40 +0800
Subject: [PATCH] fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr

---
 internlm/solver/optimizer/fsdp_optimizer.py | 16 +++++++++-------
 internlm/train/training_internlm.py         | 17 +----------------
 2 files changed, 10 insertions(+), 23 deletions(-)

diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index 463ed91..a3f21db 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -78,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)
 
     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
-        gradients = [p.grad for p in params if p.storage().size() != 0]
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
 
         return norm_group
@@ -128,11 +128,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
 
             device = self._fp32_param_tensor_groups[group_idx][0].device
-            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
             for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)
 
@@ -143,8 +143,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()
 
         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
-            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,7 +179,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         for group_id, param in self._fp32_param_tensor_groups.items():
             for p in param:
-                if p.storage().size() != 0:
+                if p.untyped_storage().size() != 0:
                     p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
 
     def state_dict(self):
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index b2f42e0..154ec04 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -108,15 +108,6 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
-        # pre-save info for tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            tp_dict = dict()
-            for name, param in model.named_parameters():
-                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                    tp_dict.update({name.replace("model.", ""): True})
-                else:
-                    tp_dict.update({name.replace("model.", ""): False})
-
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -133,15 +124,9 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=False,
+            use_orig_params=True,
         )
 
-        # re-set attribute for fsdp module with tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for (name, param), pre in zip(model.named_parameters(), tp_dict):
-                if pre in name and tp_dict[pre]:
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
     return model
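
Note (not part of the patch): below is a minimal, standalone sketch, assuming a recent PyTorch 2.x build, of the two behaviours this change leans on: Tensor.storage() is deprecated in favour of Tensor.untyped_storage(), and FSDP's use_orig_params=True keeps the original parameters visible through named_parameters(), which appears to be why the manual IS_TENSOR_PARALLEL bookkeeping is dropped.

# Sketch only; not part of the patch. Assumes PyTorch >= 2.0.
import torch

# 1) untyped_storage() replaces the deprecated storage() accessor.
#    Both report a size of 0 once a parameter's storage has been freed
#    (e.g. after FSDP shards it away), which is what the optimizer's
#    `!= 0` filters rely on.
t = torch.empty(4)                 # 4 fp32 elements -> 16 bytes
print(t.untyped_storage().size())  # 16 (size is in bytes, not elements)

# 2) With use_orig_params=True, FSDP keeps the original parameters
#    registered (as views into its flat parameter), so their fully
#    qualified names, and in practice attributes set on them such as
#    IS_TENSOR_PARALLEL, remain visible after wrapping:
#
#    model = FSDP(model, use_orig_params=True, ...)
#    for name, param in model.named_parameters():
#        ...  # original names; no need to re-apply per-param attributes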