mirror of https://github.com/hpcaitech/ColossalAI
[zero] zero optim copy chunk rather than copy tensor (#1070)
parent
4fc748f69b
commit
c5cd3b0f35
|
@ -158,6 +158,14 @@ class Chunk:
|
|||
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self.data[:self.utilized_size]).any().item()
|
||||
|
||||
def copy_(self, dest_chunk: 'Chunk'):
    """Copy the payload of ``dest_chunk`` into this chunk, in place.

    NOTE(review): the naming is inverted — ``dest_chunk`` is the chunk
    being *read from* and ``self`` is written to (see how
    ``ChunkManager.copy_chunk_group`` invokes it). The parameter name is
    kept for caller compatibility.
    """
    src = dest_chunk  # data actually flows from dest_chunk into self
    # Both chunks must still be live (not released) and shaped identically.
    assert not self.is_free and not src.is_free
    assert (self.size, self.utilized_size) == (src.size, src.utilized_size)
    self.data.copy_(src.data)
    # presumably re-syncs the per-tensor views onto the freshly written
    # storage — TODO confirm against _update_tensors_ptr
    self._update_tensors_ptr()
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
|
||||
|
@ -306,3 +314,8 @@ class ChunkManager:
|
|||
max_chunk_util = chunk_util
|
||||
best_chunk_size = chunk_size
|
||||
return best_chunk_size
|
||||
|
||||
def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
    """Bulk-copy one chunk group onto another.

    Pairs the chunks of ``src_group_name`` with those of
    ``dest_group_name`` element-wise and copies each source chunk onto
    its live destination counterpart. Destinations that have already
    been freed are skipped untouched.
    """
    dest_chunks = self.chunk_groups[dest_group_name]
    src_chunks = self.chunk_groups[src_group_name]
    for dst, src in zip(dest_chunks, src_chunks):
        if dst.is_free:
            continue  # freed destination: nothing to write into
        dst.copy_(src)
|
||||
|
|
|
@ -5,8 +5,6 @@ from torch.optim import Optimizer
|
|||
from colossalai.nn.parallel.data_parallel import ColoDDPV2
|
||||
from typing import Dict
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
|
@ -56,12 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
assert p.grad is None
|
||||
|
||||
def _update_fp16_params(self):
    """Refresh the fp16 working parameters from the fp32 master copies.

    Delegates to the chunk manager so the update happens as one bulk
    copy per chunk ('fp32_param' group -> 'fp16_param' group) rather
    than one tensor copy per parameter.

    NOTE(review): the source span interleaved the old per-parameter
    loop (``copy_tensor_to_chunk_slice`` over ``fp16_param_to_fp32_param``)
    with its replacement line from the same diff; only the chunk-level
    copy — the post-change implementation — is kept here.
    """
    self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
|
|
Loading…
Reference in New Issue