|
|
|
@ -135,18 +135,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
|
|
|
|
|
# assign master param pointers to p.data. |
|
|
|
|
# We will not trigger data copy here. |
|
|
|
|
for group in self.optimizer.param_groups: |
|
|
|
|
for group in self.optim.param_groups: |
|
|
|
|
for p in group['params']: |
|
|
|
|
p.data = self.master_params[p] |
|
|
|
|
# Now p.data is sharded |
|
|
|
|
# So optimizer states are sharded naturally |
|
|
|
|
|
|
|
|
|
ret = self.optimizer.step(*args, **kwargs) |
|
|
|
|
ret = self.optim.step(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
# Copy master param data (fp32) to payload of col_attr (fp16) |
|
|
|
|
# TODO() improve efficiency by gathering tensors into a chunk and transfering |
|
|
|
|
# a chunk. |
|
|
|
|
for group in self.optimizer.param_groups: |
|
|
|
|
for group in self.optim.param_groups: |
|
|
|
|
for p in group['params']: |
|
|
|
|
is_param_sharded = p.col_attr.data.is_sharded |
|
|
|
|
if not is_param_sharded: |
|
|
|
@ -190,7 +190,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
self._found_overflow.fill_(0.0) |
|
|
|
|
|
|
|
|
|
# check for overflow |
|
|
|
|
for group in self.optimizer.param_groups: |
|
|
|
|
for group in self.optim.param_groups: |
|
|
|
|
for p in group['params']: |
|
|
|
|
if has_inf_or_nan(p.grad): |
|
|
|
|
self._found_overflow.fill_(1.0) |
|
|
|
@ -206,7 +206,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
|
|
|
|
|
def _unscale_grads(self): |
|
|
|
|
assert self.optim_state == OptimState.SCALED |
|
|
|
|
for group in self.optimizer.param_groups: |
|
|
|
|
for group in self.optim.param_groups: |
|
|
|
|
for p in group['params']: |
|
|
|
|
if p.grad is not None: |
|
|
|
|
p.grad.data.div_(self.loss_scale) |
|
|
|
@ -216,7 +216,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
# We must set grad to None |
|
|
|
|
# Because we will judge whether local grad accumulation |
|
|
|
|
# is enabled by wheter grad is None |
|
|
|
|
self.optimizer.zero_grad(set_to_none=True) |
|
|
|
|
self.optim.zero_grad(set_to_none=True) |
|
|
|
|
|
|
|
|
|
def sync_grad(self): |
|
|
|
|
pass |
|
|
|
|