[hotfix] fix sharded optim zero grad (#604)

* fix sharded optim zero grad

* polish comments
ver217 2022-04-01 12:41:20 +08:00 committed by GitHub
parent 297b8baae2
commit 9bee119104
1 changed file with 21 additions and 3 deletions


@@ -184,7 +184,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if found_inf:
             self._logger.warning('found inf during ShardedOptimV2 step')
-            self.zero_grad()
+            self._zero_grad(recover_data=True)
             return
         self._prepare_data()
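
Why the overflow path calls `_zero_grad(recover_data=True)` instead of plain `zero_grad()`: when `reuse_fp16_shard` is enabled, the fp16 param shard's storage is reused to hold the gradient during backward, so skipping the step would leave gradient data where the weights should be. A minimal sketch of the idea, assuming a plain dict mapping params to fp32 master tensors (the names here are illustrative, not Colossal-AI's API):

import torch

def on_overflow(optim, master_params):
    # The step is skipped, so the grads are useless: drop them.
    for group in optim.param_groups:
        for p in group['params']:
            p.grad = None
            # The fp16 shard may have been overwritten by its grad,
            # so rebuild it from the fp32 master copy before the
            # next forward pass.
            with torch.no_grad():
                p.data.copy_(master_params[p].to(p.dtype))
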
@@ -246,13 +246,31 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.optim_state = OptimState.UNSCALED

    def zero_grad(self, *args, **kwargs):
        self._zero_grad()

    def _zero_grad(self, recover_data: bool = False):
        """Zero grad and maybe recover fp16 params.

        When `reuse_fp16_shard` is enabled, p.colo_attr.sharded_data_tensor
        stores the grad here, and we have to recover the fp16 params from
        the fp32 params.

        Args:
            recover_data (bool, optional): Whether to recover the fp16 param
                from the fp32 param. Defaults to False.
        """
        # We must set grad to None, because we judge whether local grad
        # accumulation is enabled by checking whether grad is None.
        # The grad here is sharded, but the next backward pass creates a
        # full grad first, which would lead to wrong accumulation.
        self.optim.zero_grad(set_to_none=True)
        for group in self.optim.param_groups:
            for p in group['params']:
                # p.colo_attr.sharded_data_tensor stores the grad now,
                # so we have to recover the fp16 param
                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                p.colo_attr.saved_grad.set_null()
                if recover_data and reuse_fp16_shard:
                    p.colo_attr.sharded_data_tensor.reset_payload(
                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))

    def sync_grad(self):
        pass
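
The two mechanisms inside `_zero_grad` can be illustrated with plain tensors: data-pointer equality detects that the fp16 shard's storage was reused for the grad, and `set_to_none=True` leaves `p.grad is None` as an unambiguous signal that no local grad accumulation is pending. A self-contained sketch using plain torch, without the Colossal-AI tensor types:

import torch

# 1) Aliasing check: when reuse_fp16_shard is on, the saved grad is just
#    a view over the param shard's storage, so data_ptr() matches.
shard = torch.zeros(4, dtype=torch.float16)
grad_in_shard = shard.view(-1)
assert grad_in_shard.data_ptr() == shard.data_ptr()  # grad reuses the shard

# 2) set_to_none: a stale sharded grad left in p.grad would be accumulated
#    into by the next backward pass, which first creates a full grad.
#    Setting it to None forces a fresh allocation and makes `p.grad is None`
#    a reliable test for whether local grad accumulation is in progress.
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.ones(4)  # stale sharded grad left over from the last step
p.grad = None           # what optim.zero_grad(set_to_none=True) does
assert p.grad is None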