mirror of https://github.com/InternLM/InternLM
branch: overlap_param
parent: bbb5651582
commit: f68f34234d
@@ -837,7 +837,6 @@ class HybridZeroOptimizer(BaseOptimizer):
        fp16_param.data.copy_(fp32_param)

        torch.cuda.synchronize()
        with torch.cuda.stream(self._comm_bcast_stream):
            self.broadcast_params()

        timer("step").stop()
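In the hunk above, the updated fp32 master weights are copied back into the fp16 working parameters, and the parameter broadcast is launched inside torch.cuda.stream(self._comm_bcast_stream), i.e. on a dedicated communication stream so it can overlap with later work on the default stream. A minimal sketch of that launch pattern follows; the buffer layout, process group, and rank ownership are illustrative assumptions, not the actual broadcast_params() body (which this diff does not show):

# Sketch: launch async parameter broadcasts on a dedicated CUDA stream.
# flat_fp16_params, owner_rank, and group are hypothetical names used only
# to illustrate the overlap pattern shown in the hunk above.
import torch
import torch.distributed as dist


def broadcast_updated_params(flat_fp16_params, comm_stream, group=None):
    handles = []
    # Make the parameter update on the default stream visible to the
    # communication stream before it starts reading the buffers.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        for owner_rank, flat_param in flat_fp16_params.items():
            handles.append(dist.broadcast(flat_param, src=owner_rank, group=group, async_op=True))
    return handles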
@@ -875,8 +874,6 @@ class HybridZeroOptimizer(BaseOptimizer):
        for handle in handles:
            handle.wait()

        torch.cuda.synchronize()

    ##################
    # FP16 Utilities #
    ##################
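The second hunk drains previously collected communication handles and then synchronizes the device before the FP16 utilities section. Assuming the handles are work objects returned by async_op=True collectives (as in the sketch above), the pattern is simply:

# Sketch: wait for all outstanding async collectives before their results are used.
import torch


def wait_all(handles):
    for handle in handles:
        handle.wait()          # block until each collective has completed
    torch.cuda.synchronize()   # make sure all queued device work has finished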
@@ -36,6 +36,7 @@ def empty_cache_and_diag(batch_count, interval=50):
    if interval <= 0:
        interval = 50

    if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
        cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)

    if batch_count % int(interval) == 0:
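This hunk guards the periodic CUDA memory analysis with the overlap_sync_param flag: when parameter synchronization is overlapped with compute, the memory diagnostic is skipped rather than run every interval batches. A self-contained sketch of the same gating, with a small stand-in config instead of gpc.config (an assumption made only so the example runs on its own):

# Sketch of gating a periodic diagnostic behind an overlap_sync_param flag.
# HybridZeroCfg and analyze_memory are stand-ins for gpc.config.hybrid_zero_optimizer
# and cuda_memory_analyze; only the control flow mirrors the hunk above.
from dataclasses import dataclass


@dataclass
class HybridZeroCfg:
    overlap_sync_param: bool = False


def analyze_memory(batch_count: int, full_report: bool) -> None:
    print(f"[batch {batch_count}] memory analysis (full={full_report})")


def empty_cache_and_diag(batch_count: int, cfg: HybridZeroCfg, interval: int = 50) -> None:
    if interval <= 0:
        interval = 50
    # Skip the diagnostic when parameter sync is overlapped with compute.
    if not cfg.overlap_sync_param:
        analyze_memory(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)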
@@ -5,6 +5,8 @@ import time

import torch

from internlm.core.context import global_context as gpc


class _Timer:
    """Timer."""
@@ -23,6 +25,7 @@ class _Timer:
        megatron_timer.reset()

        assert not self.started_, "timer has already been started"
        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
            self.stream.synchronize()
        self.start_time = time.time()
        self.started_ = True
@@ -30,6 +33,7 @@ class _Timer:
    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
            self.stream.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False
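The remaining hunks import gpc into the timer module and wrap the stream.synchronize() calls in _Timer.start() and _Timer.stop() with the same flag, so taking a timing no longer forces a CUDA stream synchronization when parameter sync overlap is enabled. Below is a minimal, self-contained sketch of such a stream-aware timer; the boolean constructor argument replaces the gpc.config lookup and is an assumption made for illustration:

# Sketch of a CUDA-stream-aware timer whose synchronization can be disabled,
# mirroring the guarded stream.synchronize() calls in the hunks above.
import time

import torch


class StreamTimer:
    """Wall-clock timer that optionally syncs the current CUDA stream."""

    def __init__(self, sync_stream: bool = True):
        self.sync_stream = sync_stream and torch.cuda.is_available()
        self.stream = torch.cuda.current_stream() if self.sync_stream else None
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        assert not self.started_, "timer has already been started"
        if self.sync_stream:
            self.stream.synchronize()  # wait for queued kernels before timestamping
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, "timer is not started"
        if self.sync_stream:
            self.stream.synchronize()  # include queued kernel time in the measurement
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

Skipping the synchronize keeps the timer from draining in-flight work (such as the overlapped parameter broadcast) just to take a timestamp, at the cost of timings that no longer account for queued GPU work.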