diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index 6000185..ab15917 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -79,6 +79,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
     def _compute_norm_with_fsdp_flatten(self, group_id):
         params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
         gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
+
+        norm_group = 0
+        if len(params) <= 0 or len(gradients) <= 0:
+            return norm_group
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
 
         return norm_group
@@ -126,6 +130,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
+            if len(self._fp32_param_tensor_groups[group_idx]) <= 0:
+                continue
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
             fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
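
The sketch below illustrates, outside of InternLM, the edge case this patch guards against: under FSDP sharding, a rank may hold no live shard of a given parameter group, so the filtered params/gradients lists come back empty and indexing `[0]` or reducing over them would fail. This is a minimal standalone sketch, not the InternLM code path; `toy_compute_norm` is a hypothetical stand-in for the real `compute_norm`.

```python
import torch


def toy_compute_norm(gradients):
    # Hypothetical stand-in: L2 norm over a list of gradient tensors.
    return torch.norm(torch.stack([g.norm() for g in gradients])).item()


def norm_for_group(params):
    # Mirror the patched behaviour: skip tensors whose storage was freed by
    # FSDP sharding, and return 0 when nothing is left on this rank instead
    # of calling the norm reduction on empty lists.
    params = [p for p in params if p.untyped_storage().size() != 0]
    gradients = [p.grad for p in params if p.grad is not None]
    if len(params) == 0 or len(gradients) == 0:
        return 0
    return toy_compute_norm(gradients)


if __name__ == "__main__":
    # Group with gradients on this rank: returns a finite norm.
    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.full((4,), 2.0)
    print(norm_for_group([p]))  # 4.0

    # Empty group (e.g. this rank owns no shard): returns 0 instead of raising.
    print(norm_for_group([]))  # 0
```

The second hunk applies the same idea to the fp32 master copies: `continue` past a group with no fp32 tensors on this rank rather than reading `[0].dtype` from an empty list.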