mirror of https://github.com/InternLM/InternLM
fix(optimizer/fsdp_optimizer.py): fsdp process empty params group (#408)
Co-authored-by: huangting4201 <huangting3@sensetime.com>
branch: pull/421/head
parent b3645b0244
commit 9a731b6e9b
@@ -79,6 +79,10 @@ class FSDPadaptOptimizer(BaseOptimizer):

    def _compute_norm_with_fsdp_flatten(self, group_id):
        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]

        norm_group = 0
        if len(params) <= 0 or len(gradients) <= 0:
            return norm_group
        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

        return norm_group
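For context, the hunk above guards the norm computation against param groups that are empty on the local rank: under FSDP flat-parameter sharding a rank may own no elements of a group, so its tensors report zero-size storage and the norm must short-circuit to 0 instead of reducing over an empty list. Below is a minimal standalone sketch of that pattern, assuming PyTorch >= 2.0 for Tensor.untyped_storage(); the function name local_grad_norm is invented here, and torch.linalg.vector_norm stands in for the repo's compute_norm. It is not InternLM code.

import torch

def local_grad_norm(param_group):
    # Keep only params whose local shard actually holds data on this rank
    # (non-empty storage) and that received a gradient this step.
    params = [p for p in param_group if p.untyped_storage().size() != 0 and p.grad is not None]
    if len(params) == 0:
        # Empty local group: return 0 instead of reducing over nothing.
        return 0.0
    # Per-tensor L2 norms, then an L2 norm over those (stand-in for compute_norm).
    per_tensor = torch.stack([torch.linalg.vector_norm(p.grad.float()) for p in params])
    return torch.linalg.vector_norm(per_tensor).item()

full = torch.nn.Parameter(torch.randn(4))
full.grad = torch.ones(4)
empty = torch.nn.Parameter(torch.empty(0))  # mimics a zero-size FSDP shard
print(local_grad_norm([full, empty]))  # norm from the non-empty param only
print(local_grad_norm([empty]))        # 0.0 -- guard path taken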
@@ -126,6 +130,8 @@ class FSDPadaptOptimizer(BaseOptimizer):

        # create gradient for fp32 params
        for group_idx in range(len(self.param_groups)):
            if len(self._fp32_param_tensor_groups[group_idx]) <= 0:
                continue
            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
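The second hunk applies the same empty-group guard when building fp32 master gradients: groups with no fp32 tensors on this rank are skipped, and only fp16 params with a non-empty local shard have their gradients cast. A hypothetical, self-contained sketch of that pattern follows (sync_master_grads is an invented name, the group containers are plain lists of tensors rather than the optimizer's internal state, and each surviving fp16 param is assumed to have a .grad this step).

import torch

def sync_master_grads(fp16_groups, fp32_groups):
    # fp16_groups / fp32_groups: lists of lists of tensors, one inner list per param group.
    for group_idx in range(len(fp16_groups)):
        if len(fp32_groups[group_idx]) == 0:
            # This rank holds no fp32 master copy for the group: nothing to do.
            continue
        dtype = fp32_groups[group_idx][0].dtype
        # Only fp16 params with a non-empty local shard contribute gradients.
        fp16_params = [p for p in fp16_groups[group_idx] if p.untyped_storage().size() != 0]
        grads_fp32 = [p.grad.to(dtype) for p in fp16_params]
        for master, grad in zip(fp32_groups[group_idx], grads_fp32):
            master.grad = grad

# Usage sketch: one group with a single fp16 param and its fp32 master copy.
p16 = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
p16.grad = torch.ones(4, dtype=torch.float16)
p32 = p16.detach().float().requires_grad_()
sync_master_grads([[p16]], [[p32]])
print(p32.grad.dtype)  # torch.float32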