From 9bee119104bf90c6c6426dd8f8fc968cd5e5c7db Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 1 Apr 2022 12:41:20 +0800
Subject: [PATCH] [hotfix] fix sharded optim zero grad (#604)

* fix sharded optim zero grad

* polish comments
---
 .../zero/sharded_optim/sharded_optim_v2.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 3f691f58e..0ce0adda6 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -184,7 +184,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
         if found_inf:
             self._logger.warning('found inf during ShardedOptimV2 step')
-            self.zero_grad()
+            self._zero_grad(recover_data=True)
             return
 
         self._prepare_data()
@@ -246,13 +246,31 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.optim_state = OptimState.UNSCALED
 
     def zero_grad(self, *args, **kwargs):
+        self._zero_grad()
+
+    def _zero_grad(self, recover_data: bool = False):
+        """Zero grad and maybe recover fp16 params.
+        When `reuse_fp16_shard` is enabled,
+        p.colo_attr.sharded_data_tensor stores the grad here,
+        so we have to recover the fp16 params from the fp32 master params.
+
+        Args:
+            recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
+        """
         # We must set grad to None
-        # Because we will judge whether local grad accumulation
-        # is enabled by wheter grad is None
+        # because the grad here is sharded,
+        # while the next backward pass creates a full-sized grad first,
+        # which would lead to wrong accumulation.
         self.optim.zero_grad(set_to_none=True)
         for group in self.optim.param_groups:
             for p in group['params']:
+                # p.colo_attr.sharded_data_tensor stores the grad now,
+                # so we have to recover the fp16 param
+                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                 p.colo_attr.saved_grad.set_null()
+                if recover_data and reuse_fp16_shard:
+                    p.colo_attr.sharded_data_tensor.reset_payload(
+                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))
 
     def sync_grad(self):
         pass
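
Why the new comments insist on set_to_none=True: a minimal sketch in plain PyTorch (an ordinary parameter and SGD, not ShardedOptimizerV2) of the difference. With set_to_none=False the old grad buffer survives and the next backward accumulates into it; with set_to_none=True the buffer is dropped, so the next backward allocates a fresh, full-sized grad instead of accumulating into a stale sharded one.

import torch

p = torch.randn(4, requires_grad=True)
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
opt.zero_grad(set_to_none=False)  # keeps the old buffer, now filled with zeros
print(p.grad)                     # tensor([0., 0., 0., 0.])

opt.zero_grad(set_to_none=True)   # drops the buffer entirely
print(p.grad)                     # None
p.sum().backward()                # autograd allocates a brand-new grad buffer
print(p.grad)                     # tensor([1., 1., 1., 1.])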
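
The reuse_fp16_shard detection added by the patch relies on a storage-aliasing test: two tensors share the same underlying buffer exactly when their data_ptr() values match. A minimal sketch of that test with hypothetical tensors (plain PyTorch, not colo_attr):

import torch

shard = torch.zeros(4, dtype=torch.float16)
grad_in_shard = shard                                # grad written into the shard's own buffer
separate_grad = torch.zeros(4, dtype=torch.float16)  # grad kept in its own buffer

print(grad_in_shard.data_ptr() == shard.data_ptr())   # True: same storage
print(separate_grad.data_ptr() == shard.data_ptr())   # False: distinct storage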
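
The recovery branch rebuilds the clobbered fp16 shard by down-casting the fp32 master copy and cloning it onto the current device. A minimal sketch of that step, using hypothetical names and plain PyTorch in place of reset_payload / colo_model_tensor_clone:

import torch

# fall back to CPU so the sketch runs anywhere; the patch assumes CUDA
device = torch.device('cuda', torch.cuda.current_device()) if torch.cuda.is_available() else 'cpu'

master_param = torch.randn(4, dtype=torch.float32, device=device)  # fp32 master copy
fp16_shard = torch.empty(4, dtype=torch.float16, device=device)    # overwritten by grad data

# down-cast and clone, mirroring colo_model_tensor_clone(master.half(), device)
fp16_shard = master_param.half().clone()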