support hybrid overlap

pull/407/head
yingtongxiong 2023-10-16 16:35:14 +08:00
parent d0f0c22cac
commit 82204eea59
3 changed files with 75 additions and 14 deletions

View File

@ -2,10 +2,10 @@ JOB_NAME = "7b_train"
DO_ALERT = False
SEQ_LEN = 4096
HIDDEN_SIZE = 4096
HIDDEN_SIZE = 8192
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 32
NUM_LAYER = 8
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"

View File

@ -360,10 +360,12 @@ class FSTPAllGatherSyncHandler:
self.module_handler[next_module] = weights_handler
def _post_forward_hook(module: nn.Module, input, output):
del self.FSTP_global_weights[module]
del self.module_handler[module]
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.module_handler:
del self.module_handler[module]
def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
def _pre_backward_hook(module: nn.Module, grad_output):
block_index = self.module_block[module]
name_index = self.module_name_index[module]
if name_index == 4:
@ -396,7 +398,8 @@ class FSTPAllGatherSyncHandler:
module.register_forward_hook(_post_forward_hook)
# module.register_backward_pre_hook(_pre_backward_hook)
# module.register_backward_hook(_post_backward_hook)
module.register_module_full_backward_pre_hook(_pre_backward_hook)
module.register_full_backward_pre_hook(_pre_backward_hook)
module.register_full_backward_hook(_post_backward_hook)
class CoarseGrainedFSTPAllGatherSyncHandler:
@ -410,6 +413,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.FSTP_blocks = []
self.FSTP_outs = []
self.FSTP_wqkvs = []
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
@ -418,6 +422,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.block_to_index = dict() # key: transformer block; value: transformer block index
self.index_to_block = dict() # key: transformer block index; value: transformer block
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
@ -430,6 +436,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for _, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
self.FSTP_blocks.append(block)
self.block_to_index[block] = idx
self.index_to_block[idx] = block
@ -441,12 +449,17 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# print(f"name: {name}", flush=True)
if name == "out_proj":
self.FSTP_outs.append(child)
self.module_to_index[child] = idx
# self.module_to_index[child] = idx
if name == "Wqkv":
self.FSTP_wqkvs.append(child)
self.module_to_index[child] = idx
# self.module_to_index[child] = idx
if isinstance(child, FSTPLinear):
self.module_to_index[child] = idx
self.block_module[idx][index] = child
self.FSTP_modules.append(child)
self.index_to_fsdp_modules[idx].append(child)
self.module_name_index[child] = index
index = index + 1
else:
continue
@ -457,6 +470,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[module] = total_weight
self.FSTP_global_handle[module] = weight_handle
self.block_handles[block].append(weight_handle)
def _register_sync_parameters_hook(self) -> None:
@ -499,6 +513,19 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for module in fsdp_modules:
del self.FSTP_global_weights[module]
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
block_index = self.module_to_index[module]
if block_index != 0:
handler = self.FSTP_global_handle[module]
handler.wait()
def _post_forward_hook_for_module(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module]
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
# start the all-gather for next block
@ -532,14 +559,47 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for module in fsdp_modules:
del self.FSTP_global_weights[module]
for block in self.FSTP_blocks:
block.register_forward_pre_hook(_pre_forward_hook_for_block)
block.register_forward_hook(_post_forward_hook_for_block)
block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
block.register_full_backward_hook(_post_backward_hook_for_block)
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if name_index == 4:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
def _post_backward_hook_for_module(module, grad_input, grad_output):
del self.FSTP_global_weights[module]
# for block in self.FSTP_blocks:
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
# block.register_forward_hook(_post_forward_hook_for_block)
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
# block.register_full_backward_hook(_post_backward_hook_for_block)
for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
# for wqkv in self.FSTP_wqkvs:
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
for module in self.FSTP_modules:
module.register_forward_pre_hook(_pre_forward_hook_for_module)
module.register_forward_hook(_post_forward_hook_for_module)
module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
module.register_full_backward_hook(_post_backward_hook_for_module)

View File

@ -110,7 +110,8 @@ def initialize_model():
model = wrap_FSDP_model(model)
if gpc.config.parallel["tensor"]["mode"] == "fstp":
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
# handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
handler._register_sync_parameters_hook()
gpc.config.fstp_handler = handler
return model