support hybrid overlap

pull/407/head
yingtongxiong 2023-10-16 16:35:14 +08:00
parent d0f0c22cac
commit 82204eea59
3 changed files with 75 additions and 14 deletions

File 1 of 3: the 7B training config

@@ -2,10 +2,10 @@ JOB_NAME = "7b_train"
 DO_ALERT = False
 
 SEQ_LEN = 4096
-HIDDEN_SIZE = 4096
+HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 8
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"

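A note on this config change: halving the depth to 8 layers while doubling the width to 8192 makes each block's linears much larger, which is the interesting regime for this commit, since FSTP all-gathers exactly those per-block weights (Wqkv, out_proj, w1, w2, w3) and the point is to hide that communication behind compute. A rough sizing sketch; the multiple-of-256 rounding of the MLP width is an assumption about the model code, not something this diff shows:

    # Rough per-block weight sizing under the new config (illustrative only).
    HIDDEN_SIZE = 8192
    MLP_RATIO = 8 / 3

    def round_up(x: int, multiple: int = 256) -> int:
        # assumed rounding rule for the MLP hidden width
        return multiple * ((x + multiple - 1) // multiple)

    mlp_dim = round_up(int(HIDDEN_SIZE * MLP_RATIO))  # 22016 under this rule
    per_block = {
        "Wqkv": 3 * HIDDEN_SIZE * HIDDEN_SIZE,   # ~201M params
        "out_proj": HIDDEN_SIZE * HIDDEN_SIZE,   # ~67M params
        "w1/w2/w3": 3 * HIDDEN_SIZE * mlp_dim,   # ~541M params combined
    }
    print({name: f"{n / 1e6:.0f}M" for name, n in per_block.items()})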
File 2 of 3: the FSTP all-gather sync handlers

@@ -360,10 +360,12 @@ class FSTPAllGatherSyncHandler:
                     self.module_handler[next_module] = weights_handler
 
         def _post_forward_hook(module: nn.Module, input, output):
-            del self.FSTP_global_weights[module]
-            del self.module_handler[module]
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]
+            if module in self.module_handler:
+                del self.module_handler[module]
 
-        def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
+        def _pre_backward_hook(module: nn.Module, grad_output):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
             if name_index == 4:
@@ -396,7 +398,8 @@ class FSTPAllGatherSyncHandler:
             module.register_forward_hook(_post_forward_hook)
             # module.register_backward_pre_hook(_pre_backward_hook)
             # module.register_backward_hook(_post_backward_hook)
-            module.register_module_full_backward_pre_hook(_pre_backward_hook)
+            module.register_full_backward_pre_hook(_pre_backward_hook)
+            module.register_full_backward_hook(_post_backward_hook)
 
 
 class CoarseGrainedFSTPAllGatherSyncHandler:
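The two changes in the hunk above go together: `register_module_full_backward_pre_hook` is the module-global registration function in `torch.nn.modules.module`, not a method on `nn.Module`, and a full backward *pre* hook receives only `(module, grad_output)`, while the full backward hook receives `(module, grad_input, grad_output)`. A minimal standalone demonstration of the two signatures (requires PyTorch >= 2.0; not from this commit):

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 4)

    def pre_bwd(module, grad_output):
        # full backward pre-hook: fires before this module's grads are computed
        print("pre-backward:", grad_output[0].shape)

    def post_bwd(module, grad_input, grad_output):
        # full backward hook: fires once grad_input is available
        print("post-backward:", grad_input[0].shape)

    lin.register_full_backward_pre_hook(pre_bwd)
    lin.register_full_backward_hook(post_bwd)
    lin(torch.randn(2, 4)).sum().backward()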
@@ -410,6 +413,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.FSTP_blocks = []
         self.FSTP_outs = []
         self.FSTP_wqkvs = []
+        self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
         self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
@@ -418,6 +422,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.block_to_index = dict()  # key: transformer block; value: transformer block index
         self.index_to_block = dict()  # key: transformer block index; value: transformer block
         self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
+        self.module_name_index = dict()  # key: FSTP module; value: index of the module's name in self.module_name
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
 
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
@@ -430,6 +436,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for _, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
+                        index = 0
+                        self.block_module[idx] = {}
                         self.FSTP_blocks.append(block)
                         self.block_to_index[block] = idx
                         self.index_to_block[idx] = block
@@ -441,12 +449,17 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                                     # print(f"name: {name}", flush=True)
                                     if name == "out_proj":
                                         self.FSTP_outs.append(child)
-                                        self.module_to_index[child] = idx
+                                        # self.module_to_index[child] = idx
                                     if name == "Wqkv":
                                         self.FSTP_wqkvs.append(child)
-                                        self.module_to_index[child] = idx
+                                        # self.module_to_index[child] = idx
                                     if isinstance(child, FSTPLinear):
+                                        self.module_to_index[child] = idx
+                                        self.block_module[idx][index] = child
+                                        self.FSTP_modules.append(child)
                                         self.index_to_fsdp_modules[idx].append(child)
+                                        self.module_name_index[child] = index
+                                        index = index + 1
                             else:
                                 continue
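What the scan above builds, in one place: every `FSTPLinear` in a block gets a per-block `name_index` in traversal order (here assumed to match `self.module_name`), recorded in both directions, and `module_to_index` now covers all five linears rather than just `out_proj` and `Wqkv`, hence the two commented-out assignments. Illustrative layout, with made-up module names:

    # block_module[block_idx][name_idx] -> FSTPLinear, name_idx following
    # self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
    block_module = {
        0: {0: "blocks.0.mixer.Wqkv", 1: "blocks.0.mixer.out_proj",
            2: "blocks.0.mlp.w1", 3: "blocks.0.mlp.w2", 4: "blocks.0.mlp.w3"},
        # ... one inner dict per transformer block
    }
    # module_name_index inverts the inner mapping (module -> 0..4);
    # module_to_index maps every FSTP linear to its block index.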
@@ -457,6 +470,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         for module in fsdp_modules:
             total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
             self.FSTP_global_weights[module] = total_weight
+            self.FSTP_global_handle[module] = weight_handle
             self.block_handles[block].append(weight_handle)
 
     def _register_sync_parameters_hook(self) -> None:
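All of this bookkeeping leans on `all_gather_raw` returning the gathered tensor together with an async work handle that can be waited on later. The helper is defined elsewhere in the repo; a plausible minimal version, assuming the weight is sharded along dim 0 and gathered with `torch.distributed`:

    import torch
    import torch.distributed as dist

    def all_gather_raw(input_: torch.Tensor, process_group, async_op: bool = False):
        # Gather each rank's shard into one contiguous tensor along dim 0.
        world_size = dist.get_world_size(process_group)
        output = torch.empty(
            world_size * input_.shape[0], *input_.shape[1:],
            dtype=input_.dtype, device=input_.device,
        )
        handle = dist.all_gather_into_tensor(
            output, input_.contiguous(), group=process_group, async_op=async_op
        )
        return output, handle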
@@ -499,6 +513,19 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
 
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+            block_index = self.module_to_index[module]
+            if block_index != 0:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+
+        def _post_forward_hook_for_module(module: nn.Module, input, output):
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_handle:
+                del self.FSTP_global_handle[module]
+
         def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             # start the all-gather for next block
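The new per-module forward hooks implement a wait-then-free discipline: the pre-forward hook blocks on the all-gather handle that was issued earlier (block 0 is skipped, presumably because its weights are gathered synchronously at the start of the step), and the post-forward hook drops the full weight immediately so only a few unsharded weights are alive at once. The generic shape of that overlap, as a standalone sketch (the `weight=` keyword is hypothetical):

    # Prefetch layer i+1's weight while layer i computes; wait only when needed.
    weights, handles = {}, {}

    def prefetch(i, layers, group):
        weights[i], handles[i] = all_gather_raw(layers[i].weight, group, async_op=True)

    def forward_all(x, layers, group):
        prefetch(0, layers, group)
        for i, layer in enumerate(layers):
            if i + 1 < len(layers):
                prefetch(i + 1, layers, group)   # comm overlaps this layer's compute
            handles.pop(i).wait()                # cheap if the gather already finished
            x = layer(x, weight=weights.pop(i))  # hypothetical signature
        return x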
@@ -532,14 +559,47 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 for module in fsdp_modules:
                     del self.FSTP_global_weights[module]
 
-        for block in self.FSTP_blocks:
-            block.register_forward_pre_hook(_pre_forward_hook_for_block)
-            block.register_forward_hook(_post_forward_hook_for_block)
-            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-            block.register_full_backward_hook(_post_backward_hook_for_block)
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+            block_index = self.module_to_index[module]
+            name_index = self.module_name_index[module]
+            if name_index == 4:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            else:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+
+        def _post_backward_hook_for_module(module, grad_input, grad_output):
+            del self.FSTP_global_weights[module]
+
+        # for block in self.FSTP_blocks:
+        #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
+        #     block.register_forward_hook(_post_forward_hook_for_block)
+        #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+        #     block.register_full_backward_hook(_post_backward_hook_for_block)
+
         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
 
         # for wqkv in self.FSTP_wqkvs:
         #     wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
+
+        for module in self.FSTP_modules:
+            module.register_forward_pre_hook(_pre_forward_hook_for_module)
+            module.register_forward_hook(_post_forward_hook_for_module)
+            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
+            module.register_full_backward_hook(_post_backward_hook_for_module)
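The backward-side chain mirrors the forward prefetch in reverse. Autograd reaches each block's linears in reverse creation order, so `name_index == 4` (w3) is the entry point: no gather was pre-issued for it, so its weight is gathered and waited on inline, and from there every pre-backward hook waits on its own pre-issued handle and launches the gather for `name_index - 1`. In outline:

    # Backward prefetch chain within one block (name_index 4 -> 0):
    #   w3  (4): gather own weight inline, then prefetch w2 (3)
    #   w2  (3): wait on pre-issued gather, prefetch w1 (2)
    #   w1  (2): wait, prefetch out_proj (1)
    #   out_proj (1): wait, prefetch Wqkv (0)
    #   Wqkv (0): wait; nothing left to prefetch in this block
    # _post_backward_hook_for_module frees the gathered weight right after use;
    # note the chain does not prefetch across block boundaries.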

File 3 of 3: model initialization

@@ -110,7 +110,8 @@ def initialize_model():
         model = wrap_FSDP_model(model)
 
     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
 
     return model
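Note this hunk flips `initialize_model` from the coarse-grained per-block handler back to the fine-grained per-module `FSTPAllGatherSyncHandler` extended above. The handler only kicks in when tensor parallelism runs in `fstp` mode; judging from the check, the config presumably carries something like the following (field layout inferred from `gpc.config.parallel["tensor"]["mode"]`, not taken from this diff):

    parallel = dict(
        tensor=dict(size=8, mode="fstp"),  # "fstp" enables the all-gather overlap handler
    )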