wgt reformat

pull/540/head
877825076@qq.com 2023-12-19 00:02:00 +08:00
parent d9c9f7c9ee
commit eae9b97ab2
1 changed file with 28 additions and 18 deletions


@@ -803,8 +803,8 @@ class ParamBcastSyncHandler:
         self._param_to_rank = dict()  # <key: param> <value: rank>
         self._block_to_rank = dict()  # <key: nn.Module> <value: rank>
         self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles)>
-        self._block_next_block = dict()  # <key: nn.Module> <value: nn.Module>
-        self._block_to_handles = dict()  # <key: nn.Module> <value: list(bcast handles)>
+        self._block_next_block = OrderedDict()  # <key: nn.Module> <value: nn.Module>
+        self._block_to_handles = OrderedDict()  # <key: nn.Module> <value: list(bcast handles)>
 
         zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
         total_param_num = sum(p.numel() for p in model.parameters())
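
Note (not part of the diff): moving `_block_next_block` and `_block_to_handles` from `dict()` to `OrderedDict()` makes the insertion order of blocks explicit, which the chain construction in the next hunk relies on. A minimal sketch of the property being used, with toy `nn.Linear` modules standing in for transformer blocks (plain `dict` also preserves insertion order on Python 3.7+, so `OrderedDict` here mainly documents intent):

from collections import OrderedDict

import torch.nn as nn

# Toy stand-ins for embedding / transformer blocks / norm / head.
blocks = [nn.Linear(4, 4) for _ in range(4)]

# Link each block to its successor in registration order, mirroring
# what self._block_next_block stores.
block_next_block = OrderedDict()
for prev, nxt in zip(blocks, blocks[1:]):
    block_next_block[prev] = nxt

# Walking the chain from the first block visits modules exactly in
# the order they were inserted, which the prefetch logic depends on.
cur = blocks[0]
visited = [cur]
while cur in block_next_block:
    cur = block_next_block[cur]
    visited.append(cur)
assert visited == blocks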
@@ -819,25 +819,35 @@ class ParamBcastSyncHandler:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
 
-            for _, children in _chunk.named_children():
+            # To preserve ordering, non-PackedFlashBaseLayer1D modules must
+            # communicate synchronously before the first fwd.
+            # BUG: the traversal order is not necessarily the actual fwd/bwd execution order!
+            last_module = None
+            for module_idx, children in _chunk.named_children():
+                if gpc.get_global_rank() == 0:
+                    print(f"children: {children.__class__.__name__}", flush=True)
                 # should be the transformer block definition in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
+                    assert module_idx != 0
                     # record the block that a parameter belongs to
-                    for _, block in enumerate(children):
+                    for layer_idx, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
-                        key_list = list(self._block_to_param.keys())
-                        if len(key_list) > 1:
-                            up_layer = key_list[-2]
-                            self._block_next_block[up_layer] = key_list[-1]
+                        self._block_next_block[last_module] = block
+                        if gpc.get_global_rank() == 0:
+                            print(f"{block.__class__.__name__}_layer_{layer_idx}", flush=True)
+                        last_module = block
                 else:
                     # record the block that a parameter belongs to
                     # self._block_to_param[name] = list(children.parameters())
                     self._block_to_param[children] = list(children.parameters())
-                    key_list = list(self._block_to_param.keys())
-                    if len(key_list) > 1:
-                        up_layer = key_list[-2]
-                        self._block_next_block[up_layer] = key_list[-1]
+                    if module_idx == 0:
+                        assert "embedding" in f"{children.__class__.__name__}"
+                        assert last_module is None
+                        self._block_next_block[children] = children
+                    else:
+                        self._block_next_block[last_module] = children
+                    last_module = children
 
         alloc_num = 0
         rank_to_go = 0
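
Review note on the hunk above: `named_children()` yields `(name, module)` pairs where the name is a string, so `assert module_idx != 0` is vacuously true and the `module_idx == 0` embedding branch can never fire as written; and, as the BUG comment says, definition order need not match forward-execution order. A minimal sketch of the intended linking, using a hypothetical helper `build_next_block_chain` (not part of the PR) that flattens any `nn.ModuleList` and links each module to its traversal successor:

from collections import OrderedDict

import torch.nn as nn

def build_next_block_chain(model: nn.Module) -> "OrderedDict[nn.Module, nn.Module]":
    # Hypothetical helper (not in the PR): link every top-level module,
    # flattening nn.ModuleList layers, to the module traversed after it.
    # The first module needs no predecessor entry; the pre-forward hook
    # in the next hunk launches its broadcast directly.
    next_block = OrderedDict()
    last_module = None
    for child in model.children():
        flat = list(child) if isinstance(child, nn.ModuleList) else [child]
        for block in flat:
            if last_module is not None:
                next_block[last_module] = block
            last_module = block
    return next_block

# Caveat mirrored from the BUG comment: children() yields modules in
# *definition* order, which need not match the order in which forward()
# actually calls them.
toy = nn.Sequential(nn.Embedding(100, 8), nn.Linear(8, 8), nn.LayerNorm(8))
assert len(build_next_block_chain(toy)) == 2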
@@ -882,20 +892,20 @@ class ParamBcastSyncHandler:
     def _register_sync_parameters_hook(self) -> None:
         def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
-            current_layer = model
-            next_layer = self._block_next_block[current_layer] if current_layer in self._block_next_block else None
+            current_module = model
+            next_layer = self._block_next_block[current_module] if current_module in self._block_next_block else None
 
             # if this is the first layer
             # launch broadcast for current layer
-            if current_layer == list(self._block_to_param.keys())[0]:
-                self._launch_handle(current_layer)
+            if current_module == list(self._block_to_param.keys())[0]:
+                self._launch_handle(current_module)
 
             # if this is not the last layer
             # launch broadcast for next layer
             if next_layer:
                 self._launch_handle(next_layer)
 
-            for handle in self._block_to_handles[current_layer]:
+            for handle in self._block_to_handles[current_module]:
                 handle.wait()
 
         # register_forward_pre_hook for transformer/embedding/norm/xxx block
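
The hook body above is the standard compute/communication overlap pattern: the first block launches its own broadcast, every block prefetches its successor's broadcast, and each block waits only on its own handles before running forward. A minimal self-contained sketch of that pattern; `register_prefetch_hooks`, `launch_handle`, and `block_to_handles` are hypothetical stand-ins for the handler's internals, not the PR's API:

from typing import Any, Callable, Dict, List, Sequence

import torch.nn as nn

def register_prefetch_hooks(
    blocks: Sequence[nn.Module],
    launch_handle: Callable[[nn.Module], None],
    block_to_handles: Dict[nn.Module, List[Any]],
) -> None:
    # launch_handle(block) is assumed to start the async broadcasts for
    # `block` and append waitable handles to block_to_handles[block].
    next_of = dict(zip(blocks, blocks[1:]))

    def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
        # First block: nobody prefetched it, so launch on entry.
        if module is blocks[0]:
            launch_handle(module)
        # Prefetch the successor so its communication overlaps with
        # this block's forward compute.
        if module in next_of:
            launch_handle(next_of[module])
        # Wait only on the handles this block actually needs.
        for handle in block_to_handles.get(module, ()):
            handle.wait()

    for block in blocks:
        block.register_forward_pre_hook(_pre_forward_hook)

With this layout the broadcast for block k+1 is in flight while block k computes, so the wait() in block k+1's pre-hook usually returns immediately.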