From d9c9f7c9ee6bc3872888335cd039873a3732079d Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Mon, 18 Dec 2023 21:37:17 +0800
Subject: [PATCH] fix

---
 .../solver/optimizer/hybrid_zero_optim.py | 19 ++++----
 internlm/solver/optimizer/utils.py        | 45 +++++++++++++++----
 internlm/utils/megatron_timers.py         |  8 +---
 3 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index f8c697f..ca7449a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -100,8 +100,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
 
-        self._comm_bcast_stream = torch.cuda.Stream()
-
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,
@@ -859,16 +857,21 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
                 # assert grank == rank, f"{grank} == {rank}"
                 g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
-                handle = dist.broadcast(
-                    fp16_param,
-                    src=g_rank,
-                    group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
-                    async_op=True,
-                )
 
                 if self._overlap_sync_param:
+                    handle = dict()  # broadcast kwargs; launched later by ParamBcastSyncHandler
+                    handle["tensor"] = fp16_param
+                    handle["src"] = g_rank
+                    handle["group"] = gpc.get_group(self._broadcast_parallel_mode[group_id])
+                    handle["async_op"] = True
                     self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
                 else:
+                    handle = dist.broadcast(
+                        fp16_param,
+                        src=g_rank,
+                        group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                        async_op=True,
+                    )
                     handles.append(handle)
 
         for handle in handles:
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index db9eefa..57130bb 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -803,6 +803,8 @@ class ParamBcastSyncHandler:
         self._param_to_rank = dict()  # <param, rank>
         self._block_to_rank = dict()  # <block, rank>
         self._bcast_handles = dict()  # <rank, [bcast handles]>
+        self._block_next_block = dict()  # <block, next block in forward order>
+        self._block_to_handles = dict()  # <block, [launched bcast handles]>
 
         zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
         total_param_num = sum(p.numel() for p in model.parameters())
@@ -824,10 +826,18 @@ class ParamBcastSyncHandler:
                 for _, block in enumerate(children):
                     # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                     self._block_to_param[block] = list(block.parameters())
+                    key_list = list(self._block_to_param.keys())
+                    if len(key_list) > 1:
+                        up_layer = key_list[-2]
+                        self._block_next_block[up_layer] = key_list[-1]
             else:
                 # record the block that a parameter belongs to
                 # self._block_to_param[name] = list(children.parameters())
                 self._block_to_param[children] = list(children.parameters())
+                key_list = list(self._block_to_param.keys())
+                if len(key_list) > 1:
+                    up_layer = key_list[-2]
+                    self._block_next_block[up_layer] = key_list[-1]
 
         alloc_num = 0
         rank_to_go = 0
@@ -857,16 +867,35 @@ class ParamBcastSyncHandler:
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         self._register_sync_parameters_hook()
 
+    def _launch_handle(self, layer):
+        handle_metas = []
+        for rank in self._block_to_rank[layer]:
+            handle_metas.extend(self._bcast_handles[rank])
+            # clear the consumed metas so the same broadcast is not launched twice
+            self._bcast_handles[rank] = []
+        # launch the async broadcasts; the pre-forward hook waits on these handles
+        handles = []
+        for handle_meta in handle_metas:
+            handle = dist.broadcast(**handle_meta)
+            handles.append(handle)
+        self._block_to_handles[layer] = handles
+
     def _register_sync_parameters_hook(self) -> None:
         def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
-            bcast_handles = []
-            # gather all required broadcast hanles into a list
-            for rank in self._block_to_rank[model]:
-                bcast_handles.extend(self._bcast_handles[rank])
-                # need to clear _bcast_handles since they would be processed later
-                self._bcast_handles[rank] = []
-            # wait all required broadcast handles to be completed
-            for handle in bcast_handles:
+            current_layer = model
+            next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
+
+            # the first block has no predecessor to prefetch it,
+            # so launch its broadcasts here
+            if current_layer == list(self._block_to_param.keys())[0]:
+                self._launch_handle(current_layer)
+
+            # prefetch: launch the next block's broadcasts so they
+            # overlap with the current block's forward computation
+            if next_layer:
+                self._launch_handle(next_layer)
+
+            for handle in self._block_to_handles[current_layer]:
                 handle.wait()
 
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
diff --git a/internlm/utils/megatron_timers.py b/internlm/utils/megatron_timers.py
index 94e52fa..d5d89e5 100644
--- a/internlm/utils/megatron_timers.py
+++ b/internlm/utils/megatron_timers.py
@@ -5,8 +5,6 @@ import time
 
 import torch
 
-from internlm.core.context import global_context as gpc
-
 
 class _Timer:
     """Timer."""
@@ -25,16 +23,14 @@
             megatron_timer.reset()
 
         assert not self.started_, "timer has already been started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True
 
     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
-            self.stream.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False
 
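Note on the change: with overlap_sync_param enabled, step() no longer issues dist.broadcast directly; it only records the broadcast arguments, and ParamBcastSyncHandler launches them from each block's forward pre-hook — the first block's broadcasts on demand, the next block's as a prefetch — then waits on the current block's handles before that block computes. The sketch below shows the same pattern in isolation. It is not the InternLM implementation: the names PrefetchBroadcaster and queue_bcast are made up for illustration, and it assumes an initialized torch.distributed process group and that the blocks run forward in the order they are registered.

# Illustrative sketch only -- not InternLM code. Assumes dist.init_process_group()
# has been called and `blocks` are nn.Modules executed in the order given.
import torch.distributed as dist


class PrefetchBroadcaster:
    """Queue broadcast kwargs per block, launch them lazily from a forward
    pre-hook, and prefetch the next block so communication overlaps compute."""

    def __init__(self, blocks):
        self._blocks = list(blocks)                    # forward execution order
        self._next = dict(zip(self._blocks[:-1], self._blocks[1:]))
        self._pending = {b: [] for b in self._blocks}  # block -> [broadcast kwargs]
        self._handles = {b: [] for b in self._blocks}  # block -> [async work handles]
        for block in self._blocks:
            block.register_forward_pre_hook(self._pre_forward_hook)

    def queue_bcast(self, block, tensor, src, group=None):
        # Record the arguments instead of calling dist.broadcast right away
        # (this mirrors what add_bcast_handle stores in the patch above).
        self._pending[block].append(dict(tensor=tensor, src=src, group=group, async_op=True))

    def _launch(self, block):
        # Issue the queued broadcasts asynchronously and keep their handles.
        for kwargs in self._pending[block]:
            self._handles[block].append(dist.broadcast(**kwargs))
        self._pending[block] = []  # consumed; never launch the same kwargs twice

    def _pre_forward_hook(self, block, inputs):  # pylint: disable=unused-argument
        if block is self._blocks[0]:
            self._launch(block)       # first block: nothing prefetched it yet
        next_block = self._next.get(block)
        if next_block is not None:
            self._launch(next_block)  # prefetch: overlaps with this block's forward
        for handle in self._handles[block]:
            handle.wait()             # params must be up to date before computing
        self._handles[block] = []

Because the broadcasts are now launched on the default stream right before the block that consumes them, the dedicated _comm_bcast_stream and the overlap_sync_param special case in megatron_timers.py are no longer needed, which is why this patch removes both.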