mirror of https://github.com/InternLM/InternLM
support hybrid overlap
parent
d0f0c22cac
commit
82204eea59
|
@ -2,10 +2,10 @@ JOB_NAME = "7b_train"
|
||||||
DO_ALERT = False
|
DO_ALERT = False
|
||||||
|
|
||||||
SEQ_LEN = 4096
|
SEQ_LEN = 4096
|
||||||
HIDDEN_SIZE = 4096
|
HIDDEN_SIZE = 8192
|
||||||
NUM_ATTENTION_HEAD = 32
|
NUM_ATTENTION_HEAD = 32
|
||||||
MLP_RATIO = 8 / 3
|
MLP_RATIO = 8 / 3
|
||||||
NUM_LAYER = 32
|
NUM_LAYER = 8
|
||||||
VOCAB_SIZE = 103168
|
VOCAB_SIZE = 103168
|
||||||
|
|
||||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||||
|
|
|
@ -360,10 +360,12 @@ class FSTPAllGatherSyncHandler:
|
||||||
self.module_handler[next_module] = weights_handler
|
self.module_handler[next_module] = weights_handler
|
||||||
|
|
||||||
def _post_forward_hook(module: nn.Module, input, output):
|
def _post_forward_hook(module: nn.Module, input, output):
|
||||||
|
if module in self.FSTP_global_weights:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
|
if module in self.module_handler:
|
||||||
del self.module_handler[module]
|
del self.module_handler[module]
|
||||||
|
|
||||||
def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
|
def _pre_backward_hook(module: nn.Module, grad_output):
|
||||||
block_index = self.module_block[module]
|
block_index = self.module_block[module]
|
||||||
name_index = self.module_name_index[module]
|
name_index = self.module_name_index[module]
|
||||||
if name_index == 4:
|
if name_index == 4:
|
||||||
|
@ -396,7 +398,8 @@ class FSTPAllGatherSyncHandler:
|
||||||
module.register_forward_hook(_post_forward_hook)
|
module.register_forward_hook(_post_forward_hook)
|
||||||
# module.register_backward_pre_hook(_pre_backward_hook)
|
# module.register_backward_pre_hook(_pre_backward_hook)
|
||||||
# module.register_backward_hook(_post_backward_hook)
|
# module.register_backward_hook(_post_backward_hook)
|
||||||
module.register_module_full_backward_pre_hook(_pre_backward_hook)
|
module.register_full_backward_pre_hook(_pre_backward_hook)
|
||||||
|
module.register_full_backward_hook(_post_backward_hook)
|
||||||
|
|
||||||
|
|
||||||
class CoarseGrainedFSTPAllGatherSyncHandler:
|
class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
|
@ -410,6 +413,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.FSTP_blocks = []
|
self.FSTP_blocks = []
|
||||||
self.FSTP_outs = []
|
self.FSTP_outs = []
|
||||||
self.FSTP_wqkvs = []
|
self.FSTP_wqkvs = []
|
||||||
|
self.FSTP_modules = []
|
||||||
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
|
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
|
||||||
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
|
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
|
||||||
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
|
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
|
||||||
|
@ -418,6 +422,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.block_to_index = dict() # key: transformer block; value: transformer block index
|
self.block_to_index = dict() # key: transformer block; value: transformer block index
|
||||||
self.index_to_block = dict() # key: transformer block index; value: transformer block
|
self.index_to_block = dict() # key: transformer block index; value: transformer block
|
||||||
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
|
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
|
||||||
|
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
||||||
|
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
||||||
|
|
||||||
# just want to share same for loop for ModuleList and Module
|
# just want to share same for loop for ModuleList and Module
|
||||||
if not isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
|
@ -430,6 +436,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for _, children in _chunk.named_children():
|
for _, children in _chunk.named_children():
|
||||||
if isinstance(children, nn.ModuleList):
|
if isinstance(children, nn.ModuleList):
|
||||||
for idx, block in enumerate(children):
|
for idx, block in enumerate(children):
|
||||||
|
index = 0
|
||||||
|
self.block_module[idx] = {}
|
||||||
self.FSTP_blocks.append(block)
|
self.FSTP_blocks.append(block)
|
||||||
self.block_to_index[block] = idx
|
self.block_to_index[block] = idx
|
||||||
self.index_to_block[idx] = block
|
self.index_to_block[idx] = block
|
||||||
|
@ -441,12 +449,17 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
# print(f"name: {name}", flush=True)
|
# print(f"name: {name}", flush=True)
|
||||||
if name == "out_proj":
|
if name == "out_proj":
|
||||||
self.FSTP_outs.append(child)
|
self.FSTP_outs.append(child)
|
||||||
self.module_to_index[child] = idx
|
# self.module_to_index[child] = idx
|
||||||
if name == "Wqkv":
|
if name == "Wqkv":
|
||||||
self.FSTP_wqkvs.append(child)
|
self.FSTP_wqkvs.append(child)
|
||||||
self.module_to_index[child] = idx
|
# self.module_to_index[child] = idx
|
||||||
if isinstance(child, FSTPLinear):
|
if isinstance(child, FSTPLinear):
|
||||||
|
self.module_to_index[child] = idx
|
||||||
|
self.block_module[idx][index] = child
|
||||||
|
self.FSTP_modules.append(child)
|
||||||
self.index_to_fsdp_modules[idx].append(child)
|
self.index_to_fsdp_modules[idx].append(child)
|
||||||
|
self.module_name_index[child] = index
|
||||||
|
index = index + 1
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -457,6 +470,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for module in fsdp_modules:
|
for module in fsdp_modules:
|
||||||
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
|
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||||
self.FSTP_global_weights[module] = total_weight
|
self.FSTP_global_weights[module] = total_weight
|
||||||
|
self.FSTP_global_handle[module] = weight_handle
|
||||||
self.block_handles[block].append(weight_handle)
|
self.block_handles[block].append(weight_handle)
|
||||||
|
|
||||||
def _register_sync_parameters_hook(self) -> None:
|
def _register_sync_parameters_hook(self) -> None:
|
||||||
|
@ -499,6 +513,19 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for module in fsdp_modules:
|
for module in fsdp_modules:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
|
|
||||||
|
|
||||||
|
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
|
||||||
|
block_index = self.module_to_index[module]
|
||||||
|
if block_index != 0:
|
||||||
|
handler = self.FSTP_global_handle[module]
|
||||||
|
handler.wait()
|
||||||
|
|
||||||
|
def _post_forward_hook_for_module(module: nn.Module, input, output):
|
||||||
|
if module in self.FSTP_global_weights:
|
||||||
|
del self.FSTP_global_weights[module]
|
||||||
|
if module in self.FSTP_global_handle:
|
||||||
|
del self.FSTP_global_handle[module]
|
||||||
|
|
||||||
def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
|
def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
|
||||||
block_index = self.module_to_index[module]
|
block_index = self.module_to_index[module]
|
||||||
# start the all-gather for next block
|
# start the all-gather for next block
|
||||||
|
@ -532,14 +559,47 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for module in fsdp_modules:
|
for module in fsdp_modules:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
|
|
||||||
for block in self.FSTP_blocks:
|
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
|
||||||
block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
block_index = self.module_to_index[module]
|
||||||
block.register_forward_hook(_post_forward_hook_for_block)
|
name_index = self.module_name_index[module]
|
||||||
block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
if name_index == 4:
|
||||||
block.register_full_backward_hook(_post_backward_hook_for_block)
|
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||||
|
weight_handler.wait()
|
||||||
|
self.FSTP_global_weights[module] = total_weight
|
||||||
|
|
||||||
|
# start the all-gather for next module
|
||||||
|
next_module = self.block_module[block_index][name_index - 1]
|
||||||
|
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||||
|
next_module.weight, self.process_group, async_op=True
|
||||||
|
)
|
||||||
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
|
else:
|
||||||
|
handler = self.FSTP_global_handle[module]
|
||||||
|
handler.wait()
|
||||||
|
if name_index != 0:
|
||||||
|
next_module = self.block_module[block_index][name_index - 1]
|
||||||
|
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||||
|
next_module.weight, self.process_group, async_op=True
|
||||||
|
)
|
||||||
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
|
|
||||||
|
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
||||||
|
del self.FSTP_global_weights[module]
|
||||||
|
|
||||||
|
# for block in self.FSTP_blocks:
|
||||||
|
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
||||||
|
# block.register_forward_hook(_post_forward_hook_for_block)
|
||||||
|
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
||||||
|
# block.register_full_backward_hook(_post_backward_hook_for_block)
|
||||||
|
|
||||||
for out_proj in self.FSTP_outs:
|
for out_proj in self.FSTP_outs:
|
||||||
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
||||||
|
|
||||||
# for wqkv in self.FSTP_wqkvs:
|
# for wqkv in self.FSTP_wqkvs:
|
||||||
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
||||||
|
|
||||||
|
for module in self.FSTP_modules:
|
||||||
|
module.register_forward_pre_hook(_pre_forward_hook_for_module)
|
||||||
|
module.register_forward_hook(_post_forward_hook_for_module)
|
||||||
|
module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
|
||||||
|
module.register_full_backward_hook(_post_backward_hook_for_module)
|
||||||
|
|
|
@ -110,7 +110,8 @@ def initialize_model():
|
||||||
model = wrap_FSDP_model(model)
|
model = wrap_FSDP_model(model)
|
||||||
|
|
||||||
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
||||||
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
# handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||||
|
handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||||
handler._register_sync_parameters_hook()
|
handler._register_sync_parameters_hook()
|
||||||
gpc.config.fstp_handler = handler
|
gpc.config.fstp_handler = handler
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue