diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index eb7aae3..f8c697f 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -837,8 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp16_param.data.copy_(fp32_param)
 
         torch.cuda.synchronize()
-        with torch.cuda.stream(self._comm_bcast_stream):
-            self.broadcast_params()
+        self.broadcast_params()
 
         timer("step").stop()
 
@@ -875,8 +874,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         for handle in handles:
             handle.wait()
 
-        torch.cuda.synchronize()
-
     ##################
     # FP16 Utilities #
     ##################
diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py
index 48ec0e3..ec61a04 100644
--- a/internlm/utils/gputest.py
+++ b/internlm/utils/gputest.py
@@ -36,7 +36,8 @@ def empty_cache_and_diag(batch_count, interval=50):
     if interval <= 0:
         interval = 50
 
-    cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
+    if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
+        cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
 
     if batch_count % int(interval) == 0:
         # there is no need to do diag on the first batch
diff --git a/internlm/utils/megatron_timers.py b/internlm/utils/megatron_timers.py
index d5d89e5..94e52fa 100644
--- a/internlm/utils/megatron_timers.py
+++ b/internlm/utils/megatron_timers.py
@@ -5,6 +5,8 @@ import time
 
 import torch
 
+from internlm.core.context import global_context as gpc
+
 
 class _Timer:
     """Timer."""
@@ -23,14 +25,16 @@
             megatron_timer.reset()
 
         assert not self.started_, "timer has already been started"
-        self.stream.synchronize()
+        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
+            self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True
 
     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        self.stream.synchronize()
+        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
+            self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False
 
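
All three files apply the same pattern: skip host-blocking CUDA synchronization when parameter sync is overlapped with later compute. Below is a minimal, self-contained sketch of that gating idea; the simplified _Timer class, the _maybe_sync helper, and the plain overlap_sync_param constructor flag are illustrative assumptions, not the InternLM implementation (which reads the flag from gpc.config.hybrid_zero_optimizer.overlap_sync_param and synchronizes a dedicated stream).

# Standalone sketch of the gating pattern used in this diff (assumed names, not InternLM code).
import time

import torch


class _Timer:
    """Wall-clock timer that optionally synchronizes the current CUDA stream."""

    def __init__(self, name, overlap_sync_param=False):
        self.name_ = name
        self.overlap_sync_param = overlap_sync_param
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = 0.0

    def _maybe_sync(self):
        # Synchronizing gives accurate timings, but it also blocks the host until
        # all queued kernels (including an overlapped parameter broadcast) finish,
        # which defeats the overlap. Skip it when overlap is enabled.
        if not self.overlap_sync_param and torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()

    def start(self):
        assert not self.started_, "timer has already been started"
        self._maybe_sync()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, "timer is not started"
        self._maybe_sync()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False


if __name__ == "__main__":
    # With overlap enabled, start/stop record host time only and never block on the GPU.
    timer = _Timer("step", overlap_sync_param=True)
    timer.start()
    time.sleep(0.01)
    timer.stop()
    print(f"{timer.name_}: {timer.elapsed_ * 1000:.2f} ms")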