|
|
|
@ -135,18 +135,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
|
|
|
|
|
# assign master param pointers to p.data.
|
|
|
|
|
# We will not trigger data copy here.
|
|
|
|
|
for group in self.optimizer.param_groups:
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
for p in group['params']:
|
|
|
|
|
p.data = self.master_params[p]
|
|
|
|
|
# Now p.data is sharded
|
|
|
|
|
# So optimizer states are sharded naturally
|
|
|
|
|
|
|
|
|
|
ret = self.optimizer.step(*args, **kwargs)
|
|
|
|
|
ret = self.optim.step(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
# Copy master param data (fp32) to payload of col_attr (fp16)
|
|
|
|
|
# TODO() improve efficiency by gathering tensors into a chunk and transfering
|
|
|
|
|
# a chunk.
|
|
|
|
|
for group in self.optimizer.param_groups:
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
for p in group['params']:
|
|
|
|
|
is_param_sharded = p.col_attr.data.is_sharded
|
|
|
|
|
if not is_param_sharded:
|
|
|
|
@ -190,7 +190,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
self._found_overflow.fill_(0.0)
|
|
|
|
|
|
|
|
|
|
# check for overflow
|
|
|
|
|
for group in self.optimizer.param_groups:
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
for p in group['params']:
|
|
|
|
|
if has_inf_or_nan(p.grad):
|
|
|
|
|
self._found_overflow.fill_(1.0)
|
|
|
|
@ -206,7 +206,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
|
|
|
|
|
def _unscale_grads(self):
|
|
|
|
|
assert self.optim_state == OptimState.SCALED
|
|
|
|
|
for group in self.optimizer.param_groups:
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
for p in group['params']:
|
|
|
|
|
if p.grad is not None:
|
|
|
|
|
p.grad.data.div_(self.loss_scale)
|
|
|
|
@ -216,7 +216,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
# We must set grad to None
|
|
|
|
|
# Because we will judge whether local grad accumulation
|
|
|
|
|
# is enabled by wheter grad is None
|
|
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
self.optim.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
|
|
def sync_grad(self):
|
|
|
|
|
pass
|
|
|
|
|