@@ -5,8 +5,6 @@ from torch.optim import Optimizer
 from colossalai.nn.parallel.data_parallel import ColoDDPV2
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 
@@ -56,12 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                     assert p.grad is None
 
     def _update_fp16_params(self):
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
-                    # TODO(ver217): copy chunk
-                    fp32_p = self.fp16_param_to_fp32_param[p]
-                    self.module.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
 
     def _check_overflow(self):
         # clear previous overflow record
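For context on the hunk above: the per-parameter fp32-to-fp16 copy in _update_fp16_params is collapsed into a single chunk-group copy. The sketch below illustrates that idea under stated assumptions; ToyChunkManager, its chunk_groups dict, and the append helper are hypothetical stand-ins, not ColossalAI's actual ChunkManager implementation.

import torch

class ToyChunkManager:
    """Toy stand-in for a chunk manager: each named group holds a list of flat
    buffers ("chunks"), and paired groups (fp16/fp32) share the same layout."""

    def __init__(self):
        self.chunk_groups = {}  # group name -> list of flat torch.Tensor chunks

    def append(self, group_name, chunk):
        # Hypothetical helper: register one flat chunk under a named group.
        self.chunk_groups.setdefault(group_name, []).append(chunk)

    def copy_chunk_group(self, dest_group_name, src_group_name):
        # One copy_() per chunk instead of one per parameter; copy_() also
        # casts dtype, e.g. fp32 master weights -> fp16 working weights.
        for dest, src in zip(self.chunk_groups[dest_group_name],
                             self.chunk_groups[src_group_name]):
            dest.copy_(src)

# Usage: after the optimizer steps the fp32 master chunks, refresh all fp16
# parameter chunks in a single pass over the chunk pairs.
manager = ToyChunkManager()
manager.append('fp16_param', torch.zeros(1024, dtype=torch.float16))
manager.append('fp32_param', torch.randn(1024, dtype=torch.float32))
manager.copy_chunk_group('fp16_param', 'fp32_param')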