fix(optimizer/fsdp_optimizer.py): fsdp process empty params group (#408)

Co-authored-by: huangting4201 <huangting3@sensetime.com>
huangting4201 2023-10-10 20:06:04 +08:00 committed by GitHub
parent b3645b0244
commit 9a731b6e9b
1 changed file with 6 additions and 0 deletions


@@ -79,6 +79,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
    def _compute_norm_with_fsdp_flatten(self, group_id):
        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
        norm_group = 0
        if len(params) <= 0 or len(gradients) <= 0:
            return norm_group
        norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
        return norm_group
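
The guard added here covers the case where, after FSDP flattening, a rank owns no non-empty parameters in a group, so compute_norm would otherwise be called with empty lists. A minimal sketch of the same pattern, with a stand-in compute_norm rather than InternLM's actual helper:

import torch

def compute_norm(gradients, parameters, last_stage=True):
    # Stand-in for internlm's compute_norm: squared L2 norm of the local grads.
    return float(sum(g.float().norm() ** 2 for g in gradients))

def group_grad_norm(fp16_param_group):
    # Under FSDP flattening a rank may hold placeholder params whose local
    # shard is empty (untyped_storage().size() == 0); filter those out first.
    params = [p for p in fp16_param_group if p.untyped_storage().size() != 0]
    gradients = [p.grad for p in params if p.grad is not None]
    if len(params) == 0 or len(gradients) == 0:
        return 0  # nothing owned on this rank -> zero contribution, no crash
    return compute_norm(gradients=gradients, parameters=params, last_stage=True)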
@@ -126,6 +130,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
        # create gradient for fp32 params
        for group_idx in range(len(self.param_groups)):
            if len(self._fp32_param_tensor_groups[group_idx]) <= 0:
                continue
            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
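
The added continue matters because the next line indexes element [0] of the group to read its dtype; on a rank whose fp32 tensor group is empty that lookup would raise IndexError before any gradients are built. A small illustrative example (the dict and its contents are made up, not the repo's state):

import torch

fp32_param_tensor_groups = {0: [torch.zeros(4, dtype=torch.float32)], 1: []}

for group_idx, group in fp32_param_tensor_groups.items():
    if len(group) <= 0:
        continue  # this rank holds nothing for the group; skip it
    dtype = group[0].dtype  # safe now; group 1 would have raised IndexError
    print(group_idx, dtype)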