Summary: update fp16 parameters with a chunk-level copy.

  * Chunk.copy_ (new): asserts that both chunks are not free and have equal
    size and utilized_size, copies the other chunk's data into self.data,
    then calls self._update_tensors_ptr().
  * ChunkManager.copy_chunk_group (new): pairs the chunks of two groups with
    zip() and, for every destination chunk that is not free, copies the
    paired source chunk into it.
  * ZeroOptimizer._update_fp16_params: the per-parameter
    copy_tensor_to_chunk_slice loop is replaced by
    copy_chunk_group('fp16_param', 'fp32_param'); the gpc and ParallelMode
    imports are removed.

Review note: the parameter of Chunk.copy_ is named dest_chunk, but it is the
source of the copy (self.data is the destination). zip() stops at the shorter
group, so groups of different lengths are truncated silently.

diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index 5d228e937..d0c8315dc 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -158,6 +158,14 @@ class Chunk:
         return torch.isinf(self.data[:self.utilized_size]).any().item() or \
             torch.isnan(self.data[:self.utilized_size]).any().item()
 
+    def copy_(self, dest_chunk: 'Chunk'):
+        assert not self.is_free
+        assert not dest_chunk.is_free
+        assert self.size == dest_chunk.size
+        assert self.utilized_size == dest_chunk.utilized_size
+        self.data.copy_(dest_chunk.data)
+        self._update_tensors_ptr()
+
 
 class ChunkManager:
 
@@ -306,3 +314,8 @@ class ChunkManager:
                 max_chunk_util = chunk_util
                 best_chunk_size = chunk_size
         return best_chunk_size
+
+    def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
+        for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
+            if not dest_chunk.is_free:
+                dest_chunk.copy_(src_chunk)
diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index fa5b84456..8cc21515d 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -5,8 +5,6 @@ from torch.optim import Optimizer
 from colossalai.nn.parallel.data_parallel import ColoDDPV2
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 
@@ -56,12 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                     assert p.grad is None
 
     def _update_fp16_params(self):
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
-                    # TODO(ver217): copy chunk
-                    fp32_p = self.fp16_param_to_fp32_param[p]
-                    self.module.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
 
     def _check_overflow(self):
         # clear previous overflow record