mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix sharded optim zero grad (#604)
* fix sharded optim zero grad
* polish comments
parent 297b8baae2
commit 9bee119104
@@ -184,7 +184,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if found_inf:
             self._logger.warning('found inf during ShardedOptimV2 step')
-            self.zero_grad()
+            self._zero_grad(recover_data=True)
             return

         self._prepare_data()
@@ -246,13 +246,31 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.optim_state = OptimState.UNSCALED

    def zero_grad(self, *args, **kwargs):
        self._zero_grad()

    def _zero_grad(self, recover_data: bool = False):
        """Zero grad and maybe recover fp16 params.

        When `reuse_fp16_shard` is enabled, p.colo_attr.sharded_data_tensor
        stores the grad here, so we have to recover the fp16 params from the
        fp32 params.

        Args:
            recover_data (bool, optional): Whether to recover the fp16 param from the fp32 param. Defaults to False.
        """
        # We must set grad to None, because we judge whether local grad
        # accumulation is enabled by checking whether grad is None.
        # The grad here is sharded, but the next backward pass creates a
        # full grad first, which would lead to wrong accumulation.
        self.optim.zero_grad(set_to_none=True)
        for group in self.optim.param_groups:
            for p in group['params']:
                # p.colo_attr.sharded_data_tensor stores the grad now,
                # so we have to recover the fp16 param
                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                p.colo_attr.saved_grad.set_null()
                if recover_data and reuse_fp16_shard:
                    p.colo_attr.sharded_data_tensor.reset_payload(
                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))

    def sync_grad(self):
        pass