mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix sharded optim zero grad (#604)
* fix sharded optim zero grad
* polish comments

pull/610/head
parent 297b8baae2
commit 9bee119104
@@ -184,7 +184,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if found_inf:
             self._logger.warning('found inf during ShardedOptimV2 step')
-            self.zero_grad()
+            self._zero_grad(recover_data=True)
             return

         self._prepare_data()
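The first hunk makes the overflow path call `_zero_grad(recover_data=True)` instead of the public `zero_grad()`. The reason is visible in the second hunk below: when `reuse_fp16_shard` is enabled, the fp16 parameter shard's storage has been reused to hold the gradient, so after a skipped step the parameter data must be rebuilt from the fp32 master copy. A minimal sketch of that idea with plain tensors (names such as `master_fp32`, `buf`, `param_fp16` are illustrative only, not ColossalAI's API):

    import torch

    # Hypothetical toy, not the library's API: one fp16 buffer is reused for
    # both the parameter shard and its gradient ("reuse_fp16_shard").
    master_fp32 = torch.randn(4, dtype=torch.float32)   # fp32 master copy kept by the optimizer
    buf = torch.empty(4, dtype=torch.float16)            # shared fp16 storage
    param_fp16 = buf
    param_fp16.copy_(master_fp32.half())                 # fp16 shard used in forward/backward

    grad_fp16 = buf                                      # backward reuses the same storage for the grad
    grad_fp16.copy_(torch.randn(4).half())               # the parameter values are now clobbered

    # The commit detects this reuse by comparing storage pointers, then
    # restores the fp16 data from fp32 when the step is skipped.
    reuse_fp16_shard = grad_fp16.data_ptr() == param_fp16.data_ptr()
    if reuse_fp16_shard:
        param_fp16.copy_(master_fp32.half())             # recover_data=True path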
@@ -246,13 +246,31 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.optim_state = OptimState.UNSCALED

     def zero_grad(self, *args, **kwargs):
+        self._zero_grad()
+
+    def _zero_grad(self, recover_data: bool = False):
+        """zero grad and maybe recover fp16 params
+        When `reuse_fp16_shard` is enabled,
+        p.colo_attr.sharded_data_tensor stores grad here.
+        We have to recover them from fp32 params.
+
+        Args:
+            recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
+        """
         # We must set grad to None
-        # Because we will judge whether local grad accumulation
-        # is enabled by wheter grad is None
+        # Because grad here is sharded
+        # But next backward pass will create a full grad first
+        # Which leads to wrong accumulation
         self.optim.zero_grad(set_to_none=True)
         for group in self.optim.param_groups:
             for p in group['params']:
+                # p.colo_attr.sharded_data_tensor stores grad now
+                # we have to recover fp16 param
+                reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                 p.colo_attr.saved_grad.set_null()
+                if recover_data and reuse_fp16_shard:
+                    p.colo_attr.sharded_data_tensor.reset_payload(
+                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))

     def sync_grad(self):
         pass
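The rewritten comments point at the actual hazard: `saved_grad` holds a sharded gradient, while the next backward pass first produces a full-size gradient, so accumulating into a stale buffer would mix the two. Setting grads to None (rather than zeroing them in place) guarantees no accumulation into leftover buffers. A toy sketch of the accumulation mechanism in plain PyTorch (unrelated to sharding, just showing why the old buffer must be dropped):

    import torch

    # Autograd accumulates into an existing .grad buffer; only grad = None
    # guarantees the next backward starts from a fresh tensor.
    p = torch.nn.Parameter(torch.ones(3))

    (p * 2).sum().backward()
    print(p.grad)            # tensor([2., 2., 2.])

    (p * 3).sum().backward()
    print(p.grad)            # tensor([5., 5., 5.]) -- accumulated into the old buffer

    p.grad = None            # what self.optim.zero_grad(set_to_none=True) does
    (p * 3).sum().backward()
    print(p.grad)            # tensor([3., 3., 3.]) -- fresh buffer, no stale accumulation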