overlap_param

pull/540/head
lijiaxing 2023-12-13 19:02:19 +08:00
parent bbb5651582
commit f68f34234d
3 changed files with 9 additions and 7 deletions

@@ -837,7 +837,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                fp16_param.data.copy_(fp32_param)

        torch.cuda.synchronize()
        with torch.cuda.stream(self._comm_bcast_stream):
            self.broadcast_params()

        timer("step").stop()
@@ -875,8 +874,6 @@ class HybridZeroOptimizer(BaseOptimizer):
        for handle in handles:
            handle.wait()

        torch.cuda.synchronize()

    ##################
    # FP16 Utilities #
    ##################
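Taken together, the two hunks above show the optimizer broadcasting the updated fp16 parameters on a dedicated communication stream (self._comm_bcast_stream) and waiting on async handles afterwards, so the broadcast can overlap with other work instead of blocking the step. Below is a minimal sketch of that pattern, not the repository's code: it assumes an initialized torch.distributed process group with a CUDA backend, and the names broadcast_params_overlapped, params, and src_rank are illustrative.

    import torch
    import torch.distributed as dist

    def broadcast_params_overlapped(params, comm_stream, src_rank=0):
        """Launch parameter broadcasts on a side stream; return async handles."""
        # Ensure the comm stream sees the freshly updated parameter values
        # (e.g. the fp32 -> fp16 copy issued on the current stream).
        comm_stream.wait_stream(torch.cuda.current_stream())
        handles = []
        with torch.cuda.stream(comm_stream):
            for p in params:
                # async_op=True returns immediately with a Work handle, so the
                # broadcast kernels run on comm_stream while the caller moves on.
                handles.append(dist.broadcast(p.data, src=src_rank, async_op=True))
        return handles

    # Later, before the parameters are read again (e.g. the next forward pass):
    #     for handle in handles:
    #         handle.wait()

The wait_stream call orders the broadcast after the parameter update without blocking the host; only the final handle.wait() loop blocks the caller.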

@@ -36,6 +36,7 @@ def empty_cache_and_diag(batch_count, interval=50):
    if interval <= 0:
        interval = 50

    if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
        cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)

    if batch_count % int(interval) == 0:
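The one-line guard above disables the periodic CUDA memory analysis whenever overlap_sync_param is set, presumably so the diagnostic does not run while the parameter broadcast may still be in flight. A self-contained sketch of the same guard, where SimpleNamespace stands in for gpc.config and cuda_memory_analyze is stubbed out (both are assumptions for illustration):

    from types import SimpleNamespace

    # Stand-in for gpc.config; the real flag lives at
    # gpc.config.hybrid_zero_optimizer.overlap_sync_param.
    config = SimpleNamespace(hybrid_zero_optimizer=SimpleNamespace(overlap_sync_param=True))

    def cuda_memory_analyze(batch_count, do_diag):
        # Stub for the repository's diagnostic routine.
        print(f"memory analysis at batch {batch_count}, full diag: {do_diag}")

    def empty_cache_and_diag(batch_count, interval=50):
        if interval <= 0:
            interval = 50
        # With overlap enabled, skip the diagnostic entirely.
        if not config.hybrid_zero_optimizer.overlap_sync_param:
            cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)

    empty_cache_and_diag(5)  # prints nothing while overlap_sync_param is True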

@@ -5,6 +5,8 @@ import time
import torch

from internlm.core.context import global_context as gpc


class _Timer:
    """Timer."""
@@ -23,6 +25,7 @@ class _Timer:
            megatron_timer.reset()

        assert not self.started_, "timer has already been started"
        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
            self.stream.synchronize()
        self.start_time = time.time()
        self.started_ = True
@@ -30,6 +33,7 @@ class _Timer:
    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
            self.stream.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False
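Both timer hunks gate the existing self.stream.synchronize() on the same flag, so timing a step no longer forces a stream drain when parameter sync is overlapped. A runnable sketch of a timer with that behavior; the constructor (including where self.stream comes from) and the module-level flag are assumptions, since the diff only shows the guarded synchronize:

    import time
    import torch

    # Stand-in for gpc.config.hybrid_zero_optimizer.overlap_sync_param.
    OVERLAP_SYNC_PARAM = True

    class _Timer:
        """Timer that synchronizes a CUDA stream only when overlap is disabled."""

        def __init__(self, name):
            self.name_ = name
            self.elapsed_ = 0.0
            self.started_ = False
            self.start_time = time.time()
            # Assumed: capture the stream that is current at construction time.
            self.stream = torch.cuda.current_stream() if torch.cuda.is_available() else None

        def _maybe_sync(self):
            # Skip the host-blocking synchronize while parameter sync overlaps.
            if self.stream is not None and not OVERLAP_SYNC_PARAM:
                self.stream.synchronize()

        def start(self):
            assert not self.started_, "timer has already been started"
            self._maybe_sync()
            self.start_time = time.time()
            self.started_ = True

        def stop(self):
            assert self.started_, "timer is not started"
            self._maybe_sync()
            self.elapsed_ += time.time() - self.start_time
            self.started_ = False

With the flag enabled, the timer records host wall-clock time without waiting for outstanding GPU work.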