mirror of https://github.com/InternLM/InternLM
overlap_param
parent
bbb5651582
commit
f68f34234d
|
@ -837,8 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
fp16_param.data.copy_(fp32_param)
|
fp16_param.data.copy_(fp32_param)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
with torch.cuda.stream(self._comm_bcast_stream):
|
self.broadcast_params()
|
||||||
self.broadcast_params()
|
|
||||||
|
|
||||||
timer("step").stop()
|
timer("step").stop()
|
||||||
|
|
||||||
|
@ -875,8 +874,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
for handle in handles:
|
for handle in handles:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# FP16 Utilities #
|
# FP16 Utilities #
|
||||||
##################
|
##################
|
||||||
|
|
|
@ -36,7 +36,8 @@ def empty_cache_and_diag(batch_count, interval=50):
|
||||||
if interval <= 0:
|
if interval <= 0:
|
||||||
interval = 50
|
interval = 50
|
||||||
|
|
||||||
cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
|
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
|
||||||
|
cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
|
||||||
|
|
||||||
if batch_count % int(interval) == 0:
|
if batch_count % int(interval) == 0:
|
||||||
# there is no need to do diag on the first batch
|
# there is no need to do diag on the first batch
|
||||||
|
|
|
@ -5,6 +5,8 @@ import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from internlm.core.context import global_context as gpc
|
||||||
|
|
||||||
|
|
||||||
class _Timer:
|
class _Timer:
|
||||||
"""Timer."""
|
"""Timer."""
|
||||||
|
@ -23,14 +25,16 @@ class _Timer:
|
||||||
megatron_timer.reset()
|
megatron_timer.reset()
|
||||||
|
|
||||||
assert not self.started_, "timer has already been started"
|
assert not self.started_, "timer has already been started"
|
||||||
self.stream.synchronize()
|
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
|
||||||
|
self.stream.synchronize()
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.started_ = True
|
self.started_ = True
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the timer."""
|
"""Stop the timer."""
|
||||||
assert self.started_, "timer is not started"
|
assert self.started_, "timer is not started"
|
||||||
self.stream.synchronize()
|
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
|
||||||
|
self.stream.synchronize()
|
||||||
self.elapsed_ += time.time() - self.start_time
|
self.elapsed_ += time.time() - self.start_time
|
||||||
self.started_ = False
|
self.started_ = False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue