mirror of https://github.com/hpcaitech/ColossalAI
rename variables
parent
46add4a5c5
commit
e99af94ab8
|
@ -135,18 +135,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
|
||||
# assign master param pointers to p.data.
|
||||
# We will not trigger data copy here.
|
||||
for group in self.optimizer.param_groups:
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
p.data = self.master_params[p]
|
||||
# Now p.data is sharded
|
||||
# So optimizer states are sharded naturally
|
||||
|
||||
ret = self.optimizer.step(*args, **kwargs)
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
|
||||
# Copy master param data (fp32) to payload of col_attr (fp16)
|
||||
# TODO() improve efficiency by gathering tensors into a chunk and transfering
|
||||
# a chunk.
|
||||
for group in self.optimizer.param_groups:
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
is_param_sharded = p.col_attr.data.is_sharded
|
||||
if not is_param_sharded:
|
||||
|
@ -190,7 +190,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self._found_overflow.fill_(0.0)
|
||||
|
||||
# check for overflow
|
||||
for group in self.optimizer.param_groups:
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if has_inf_or_nan(p.grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
|
@ -206,7 +206,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
|
||||
def _unscale_grads(self):
|
||||
assert self.optim_state == OptimState.SCALED
|
||||
for group in self.optimizer.param_groups:
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is not None:
|
||||
p.grad.data.div_(self.loss_scale)
|
||||
|
@ -216,7 +216,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# We must set grad to None
|
||||
# Because we will judge whether local grad accumulation
|
||||
# is enabled by wheter grad is None
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def sync_grad(self):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue