mirror of https://github.com/InternLM/InternLM
wgt reformat
parent
d9c9f7c9ee
commit
eae9b97ab2
|
@ -803,8 +803,8 @@ class ParamBcastSyncHandler:
|
|||
self._param_to_rank = dict() # <key: param> <value: rank)>
|
||||
self._block_to_rank = dict() # <key: nn.Module> <value: rank)>
|
||||
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles))>
|
||||
self._block_next_block = dict() # <key: nn.Module> <value: nn.Module>
|
||||
self._block_to_handles = dict() # <key: nn.Module> <value: list(bcast handles)>
|
||||
self._block_next_block = OrderedDict() # <key: nn.Module> <value: nn.Module>
|
||||
self._block_to_handles = OrderedDict() # <key: nn.Module> <value: list(bcast handles)>
|
||||
|
||||
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
total_param_num = sum(p.numel() for p in model.parameters())
|
||||
|
@ -819,25 +819,35 @@ class ParamBcastSyncHandler:
|
|||
if isinstance(_chunk, NaiveAMPModel):
|
||||
_chunk = _chunk.model
|
||||
|
||||
for _, children in _chunk.named_children():
|
||||
# In order to ensure orderliness, non-PackedFlashBaseLayer1D modules are required to
|
||||
# communicate synchronously before the first fwd.
|
||||
# BUG: The order of traversal is not necessarily the order of actual fwd/bwd execution!!!
|
||||
last_module = None
|
||||
for module_idx, children in _chunk.named_children():
|
||||
if gpc.get_global_rank() == 0:
|
||||
print(f"children: {children.__class__.__name__}", flush=True)
|
||||
# should be the transformer block definition in modeling_xxx.py
|
||||
if isinstance(children, nn.ModuleList):
|
||||
assert module_idx != 0
|
||||
# record the block that a parameter belongs to
|
||||
for _, block in enumerate(children):
|
||||
for layer_idx, block in enumerate(children):
|
||||
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
|
||||
self._block_to_param[block] = list(block.parameters())
|
||||
key_list = list(self._block_to_param.keys())
|
||||
if len(key_list) > 1:
|
||||
up_layer = key_list[-2]
|
||||
self._block_next_block[up_layer] = key_list[-1]
|
||||
self._block_next_block[last_module] = block
|
||||
if gpc.get_global_rank() == 0:
|
||||
print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
|
||||
last_module = block
|
||||
else:
|
||||
# record the block that a parameter belongs to
|
||||
# self._block_to_param[name] = list(children.parameters())
|
||||
self._block_to_param[children] = list(children.parameters())
|
||||
key_list = list(self._block_to_param.keys())
|
||||
if len(key_list) > 1:
|
||||
up_layer = key_list[-2]
|
||||
self._block_next_block[up_layer] = key_list[-1]
|
||||
if module_idx == 0:
|
||||
assert "embedding" in f"{children.__class__.__name__}"
|
||||
assert last_module is None
|
||||
self._block_next_block[children] = children
|
||||
else:
|
||||
self._block_next_block[last_module] = children
|
||||
|
||||
last_module = children
|
||||
|
||||
alloc_num = 0
|
||||
rank_to_go = 0
|
||||
|
@ -882,20 +892,20 @@ class ParamBcastSyncHandler:
|
|||
|
||||
def _register_sync_parameters_hook(self) -> None:
|
||||
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
|
||||
current_layer = model
|
||||
next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
|
||||
current_module = model
|
||||
next_layer = self._block_next_block[current_module] if current_module in self._block_next_block else None
|
||||
|
||||
# if this is the first layer
|
||||
# launch broadcast for current layer
|
||||
if current_layer == list(self._block_to_param.keys())[0]:
|
||||
self._launch_handle(current_layer)
|
||||
if current_module == list(self._block_to_param.keys())[0]:
|
||||
self._launch_handle(current_module)
|
||||
|
||||
# if this is not the last layer
|
||||
# launch broadcast for next layer
|
||||
if next_layer:
|
||||
self._launch_handle(next_layer)
|
||||
|
||||
for handle in self._block_to_handles[current_layer]:
|
||||
for handle in self._block_to_handles[current_module]:
|
||||
handle.wait()
|
||||
|
||||
# register_forward_pre_hook for transformer/embedding/norm/xxx block
|
||||
|
|
Loading…
Reference in New Issue