[zero] zero optim copy chunk rather than copy tensor (#1070)

Ref: pull/1072/head
Author: ver217, committed by GitHub, 3 years ago
Parent: 4fc748f69b
Commit: c5cd3b0f35

@@ -158,6 +158,14 @@ class Chunk:
         return torch.isinf(self.data[:self.utilized_size]).any().item() or \
             torch.isnan(self.data[:self.utilized_size]).any().item()
 
+    def copy_(self, dest_chunk: 'Chunk'):
+        assert not self.is_free
+        assert not dest_chunk.is_free
+        assert self.size == dest_chunk.size
+        assert self.utilized_size == dest_chunk.utilized_size
+        self.data.copy_(dest_chunk.data)
+        self._update_tensors_ptr()
+
 
 class ChunkManager:
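Why this helps (an illustrative note, not part of the commit): a chunk stores its member tensors as views into one flat buffer, so a single `copy_` on that buffer replaces a whole series of per-tensor copies. The `FlatChunk` class below is a minimal sketch of that idea, not ColossalAI's actual `Chunk`.

```python
import torch


class FlatChunk:
    """Toy stand-in for a chunk: member tensors are views into one flat buffer."""

    def __init__(self, numels, dtype=torch.float):
        self.data = torch.zeros(sum(numels), dtype=dtype)
        self.tensors = []
        offset = 0
        for n in numels:
            # each "parameter" is a narrow view over the shared buffer
            self.tensors.append(self.data.narrow(0, offset, n))
            offset += n
        self.utilized_size = offset

    def copy_from(self, src: 'FlatChunk'):
        # one flat copy replaces len(self.tensors) per-tensor copies;
        # the views stay valid because only the buffer contents change
        assert self.utilized_size == src.utilized_size
        self.data.copy_(src.data)


fp32 = FlatChunk([4, 8, 16], dtype=torch.float)
fp16 = FlatChunk([4, 8, 16], dtype=torch.half)
fp32.data.uniform_()
fp16.copy_from(fp32)  # implicit fp32 -> fp16 cast, done in a single copy
print(torch.allclose(fp16.tensors[1].float(), fp32.tensors[1], atol=1e-3))
```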
@@ -306,3 +314,8 @@ class ChunkManager:
                 max_chunk_util = chunk_util
                 best_chunk_size = chunk_size
         return best_chunk_size
+
+    def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
+        for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
+            if not dest_chunk.is_free:
+                dest_chunk.copy_(src_chunk)
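How the new method is meant to be driven, as a self-contained sketch: the manager keeps one list of chunks per named group, and mirrored groups (here 'fp32_param' and 'fp16_param') are assumed to line up index by index, which is what lets `zip` pair them off. `ToyChunkManager` and its `append` helper are illustrative, not the real `ChunkManager` API.

```python
import torch


class ToyChunkManager:
    """Illustrative manager: parallel lists of flat buffers, keyed by group name."""

    def __init__(self):
        self.chunk_groups = {}  # group name -> list of flat buffers

    def append(self, group_name, numel, dtype):
        self.chunk_groups.setdefault(group_name, []).append(
            torch.zeros(numel, dtype=dtype))

    def copy_chunk_group(self, dest_group_name, src_group_name):
        # chunks of the two groups are assumed to correspond index by index,
        # so the whole group is refreshed with one flat copy per chunk
        for dest, src in zip(self.chunk_groups[dest_group_name],
                             self.chunk_groups[src_group_name]):
            dest.copy_(src)


mgr = ToyChunkManager()
for numel in (32, 64):
    mgr.append('fp32_param', numel, torch.float)
    mgr.append('fp16_param', numel, torch.half)
for buf in mgr.chunk_groups['fp32_param']:
    buf.normal_()
mgr.copy_chunk_group('fp16_param', 'fp32_param')  # fp32 masters -> fp16 working copies
```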

@@ -5,8 +5,6 @@ from torch.optim import Optimizer
 from colossalai.nn.parallel.data_parallel import ColoDDPV2
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
@@ -56,12 +54,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             assert p.grad is None
 
     def _update_fp16_params(self):
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
-                    # TODO(ver217): copy chunk
-                    fp32_p = self.fp16_param_to_fp32_param[p]
-                    self.module.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
 
     def _check_overflow(self):
         # clear previous overflow record
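The net effect on the optimizer: `_update_fp16_params` no longer walks every parameter group and copies each fp32 master tensor into its fp16 chunk slice; it hands the whole job to the chunk manager, which copies one flat chunk at a time. A sketch of where that call sits in a step, under the assumption of this simplified ordering (`step_sketch` is a hypothetical helper, not ColossalAI's `ZeroOptimizer.step`):

```python
def step_sketch(optim, chunk_manager):
    """Illustrative ordering only, not the real ZeroOptimizer.step."""
    # 1. update the fp32 master weights held by the wrapped optimizer
    optim.step()
    # 2. flush the updated masters back into the fp16 working parameters,
    #    one whole chunk at a time (the call introduced by this commit)
    chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
```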
