diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 57130bb..0c60329 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -803,8 +803,8 @@ class ParamBcastSyncHandler:
         self._param_to_rank = dict()  #
         # self._block_to_rank = dict()  #
         # self._bcast_handles = dict()  #
-        self._block_next_block = dict()  #
-        self._block_to_handles = dict()  #
+        self._block_next_block = OrderedDict()  #
+        self._block_to_handles = OrderedDict()  #
 
         zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
         total_param_num = sum(p.numel() for p in model.parameters())
@@ -819,25 +819,35 @@ class ParamBcastSyncHandler:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
 
-            for _, children in _chunk.named_children():
+            # In order to ensure orderliness, non-PackedFlashBaseLayer1D modules are required to
+            # communicate synchronously before the first fwd.
+            # BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
+            last_module = None
+            for module_idx, children in _chunk.named_children():
+                if gpc.get_global_rank() == 0:
+                    print(f"children: {children.__class__.__name__}", flush=True)
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
+                    assert module_idx != 0
                     # record the block that a parameter belongs to
-                    for _, block in enumerate(children):
+                    for layer_idx, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
-                        key_list = list(self._block_to_param.keys())
-                        if len(key_list) > 1:
-                            up_layer = key_list[-2]
-                            self._block_next_block[up_layer] = key_list[-1]
+                        self._block_next_block[last_module] = block
+                        if gpc.get_global_rank() == 0:
+                            print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
+                        last_module = block
                 else:
                     # record the block that a parameter belongs to
-                    # self._block_to_param[name] = list(children.parameters())
                     self._block_to_param[children] = list(children.parameters())
-                    key_list = list(self._block_to_param.keys())
-                    if len(key_list) > 1:
-                        up_layer = key_list[-2]
-                        self._block_next_block[up_layer] = key_list[-1]
+                    if module_idx == 0:
+                        assert "embedding" in f"{children.__class__.__name__}"
+                        assert last_module is None
+                        self._block_next_block[children] = children
+                    else:
+                        self._block_next_block[last_module] = children
+
+                    last_module = children
 
         alloc_num = 0
         rank_to_go = 0
@@ -882,20 +892,20 @@ class ParamBcastSyncHandler:
 
     def _register_sync_parameters_hook(self) -> None:
         def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
-            current_layer = model
-            next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
+            current_module = model
+            next_layer = self._block_next_block[current_module] if current_module in self._block_next_block else None
 
             # if this is the first layer
             # launch broadcast for current layer
-            if current_layer == list(self._block_to_param.keys())[0]:
-                self._launch_handle(current_layer)
+            if current_module == list(self._block_to_param.keys())[0]:
+                self._launch_handle(current_module)
 
             # if this is not the last layer
             # launch broadcast for next layer
             if next_layer:
                 self._launch_handle(next_layer)
 
-            for handle in self._block_to_handles[current_layer]:
+            for handle in self._block_to_handles[current_module]:
                 handle.wait()
 
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
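
Note: below is a minimal, runnable sketch of the pattern this patch implements, outside of InternLM: build an ordered successor map over modules (OrderedDict, so insertion order is preserved), then use a forward pre-hook to launch the *next* block's broadcast asynchronously while waiting on the *current* block's handles before it computes. All names here (PrefetchBcastHandler, FakeWork, Toy) are hypothetical; a real handler would issue one torch.distributed.broadcast(..., async_op=True) per parameter instead of the FakeWork stand-in, and the sketch handles the first module explicitly rather than via the patch's self-loop entry. Like the patch (see its BUG comment), it assumes named_children() traversal order matches actual forward execution order.

from collections import OrderedDict
from typing import Any, Dict, List, Optional

import torch
from torch import nn


class FakeWork:
    """Stand-in for an async torch.distributed work handle (hypothetical)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def wait(self) -> None:
        print(f"wait: params of {self.name} are now in sync")


class PrefetchBcastHandler:
    def __init__(self, model: nn.Module) -> None:
        # OrderedDict keeps insertion order, so the successor map reflects
        # the order in which blocks were registered (assumed == fwd order).
        self._next: "OrderedDict[nn.Module, nn.Module]" = OrderedDict()
        self._handles: Dict[nn.Module, List[FakeWork]] = {}
        self._first: Optional[nn.Module] = None

        last: Optional[nn.Module] = None
        for _, child in model.named_children():
            # Flatten ModuleLists (e.g. transformer blocks) into the chain.
            blocks = list(child) if isinstance(child, nn.ModuleList) else [child]
            for block in blocks:
                if last is None:
                    self._first = block
                else:
                    self._next[last] = block
                last = block

        for module in [self._first, *self._next.values()]:
            module.register_forward_pre_hook(self._pre_forward_hook)

    def _launch(self, module: nn.Module) -> None:
        # Real code: one async broadcast per parameter of `module`.
        self._handles[module] = [FakeWork(module.__class__.__name__)]

    def _pre_forward_hook(self, module: nn.Module, inputs: Any) -> None:  # pylint: disable=W0613
        if module is self._first:
            self._launch(module)  # nobody prefetched the first block
        nxt = self._next.get(module)
        if nxt is not None:
            self._launch(nxt)  # overlap: prefetch next block's params
        for work in self._handles.get(module, []):
            work.wait()  # current block must be in sync before computing


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))
        self.norm = nn.LayerNorm(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)
        for blk in self.blocks:
            h = blk(h)
        return self.norm(h)


model = Toy()
PrefetchBcastHandler(model)
model(torch.tensor([1, 2]))  # hooks fire in embed -> block0 -> block1 -> norm order

The explicit successor map is what makes the prefetch deterministic: each module knows exactly whose broadcast to kick off, instead of deriving "the previous key" from dict ordering at registration time as the removed key_list logic did.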