feat(model/linear.py): block-grained backward

pull/407/head
huangting4201 2023-10-17 10:13:56 +08:00
parent 0d1fa037dd
commit d1af0d6aee
2 changed files with 45 additions and 39 deletions

@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 8
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -161,10 +161,11 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=-1, fsdp=False),
+    tensor=dict(size=8, mode="fstp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
+    block_0_full_weight=True,
 )
 cudnn_deterministic = False
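
Most of the config changes above just shrink the run for debugging (4 layers, 20 steps); the functional additions are tensor=dict(size=8, mode="fstp") and the new block_0_full_weight flag. A rough sketch of how that flag might be consumed on the handler side; only the flag name and the parallel dict come from this diff, while the helper and its dict-style lookup are hypothetical:

# Hypothetical helper, not part of this commit: block 0 is the last block
# visited during backward, so giving it a full (unsharded) weight copy
# removes the one all-gather that could not be overlapped with compute.
def keep_full_weight(block_index: int, parallel_cfg: dict) -> bool:
    return block_index == 0 and parallel_cfg.get("block_0_full_weight", False)

parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=8, mode="fstp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
    block_0_full_weight=True,
)
assert keep_full_weight(0, parallel)
assert not keep_full_weight(1, parallel)

Presumably this trades a little memory for latency: an all-gather for block 0 would sit on the backward critical path with nothing left to overlap it with.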

model/linear.py

@@ -559,21 +559,21 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             block_index = self.block_to_index[block]
-            if block_index == gpc.config.NUM_LAYER - 1:
-                # all gather weight for the last block
-                fsdp_modules = self.index_to_fsdp_modules[block_index]
-                for module in fsdp_modules:
-                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handle.wait()
-                    self.FSTP_global_weights[module] = total_weight
-            else:
-                # wait handle for current block
-                handles = self.block_handles[block]
-                for handle in handles:
-                    handle.wait()
+            # if block_index == gpc.config.NUM_LAYER - 1:
+            #     # all gather weight for the last block
+            #     fsdp_modules = self.index_to_fsdp_modules[block_index]
+            #     for module in fsdp_modules:
+            #         total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            #         weight_handle.wait()
+            #         self.FSTP_global_weights[module] = total_weight
+            # else:
+            #     # wait handle for current block
+            #     handles = self.block_handles[block]
+            #     for handle in handles:
+            #         handle.wait()

             # start the all-gather for next block
-            if block_index - 1 >= 0:
+            if block_index - 1 > 0:
                 self._all_gather_block_weight(block_index - 1)

         def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
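
The block hook now leaves all waiting to the module-grained hooks below and only keeps the prefetch: when block i enters its backward, the weights of block i - 1 (the next block in backward order) are all-gathered asynchronously. A sketch of the two helpers this relies on, with bodies reconstructed for illustration rather than taken from the repo (the flat output layout of all_gather_raw and the bookkeeping in _all_gather_block_weight are assumptions):

import torch
import torch.distributed as dist

def all_gather_raw(shard: torch.Tensor, process_group, async_op: bool = False):
    # Gather every rank's weight shard into one flat tensor. With
    # async_op=True the returned handle lets the caller overlap the
    # communication with backward compute and wait() only when the
    # gathered weight is actually needed.
    world_size = dist.get_world_size(process_group)
    output = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(output, shard.contiguous(), group=process_group, async_op=async_op)
    return output, handle

def _all_gather_block_weight(self, block_index):
    # Method of the handler (sketch): prefetch all sharded weights of one
    # block without blocking, and record the tensors and handles so later
    # hooks can wait on them and free them after use.
    for module in self.index_to_fsdp_modules[block_index]:
        total_weight, handle = all_gather_raw(module.weight, self.process_group, async_op=True)
        self.FSTP_global_weights[module] = total_weight
        self.FSTP_global_handle[module] = handle
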
@@ -588,36 +588,41 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
             if block_index != 0:
-                if name_index == 4:
-                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler.wait()
-                    self.FSTP_global_weights[module] = total_weight
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
-                else:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
+                # if name_index == 4:
+                #     total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                #     weight_handler.wait()
+                #     self.FSTP_global_weights[module] = total_weight
+                #     # start the all-gather for next module
+                #     next_module = self.block_module[block_index][name_index - 1]
+                #     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #         next_module.weight, self.process_group, async_op=True
+                #     )
+                #     self.FSTP_global_handle[next_module] = weights_handler
+                # else:
+                #     handler = self.FSTP_global_handle[module]
+                #     handler.wait()
+                #     if name_index != 0:
+                #         next_module = self.block_module[block_index][name_index - 1]
+                #         self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #             next_module.weight, self.process_group, async_op=True
+                #         )
+                #         self.FSTP_global_handle[next_module] = weights_handler
+                if module in self.FSTP_global_handle:
+                    handler = self.FSTP_global_handle[module]
+                    handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler

         def _post_backward_hook_for_module(module, grad_input, grad_output):
             if module in self.FSTP_global_weights:
                 del self.FSTP_global_weights[module]
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]

-        # for block in self.FSTP_blocks:
-        #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
-        #     block.register_forward_hook(_post_forward_hook_for_block)
-        #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-        #     block.register_full_backward_hook(_post_backward_hook_for_block)
+        for block in self.FSTP_blocks:
+            # block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            # block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            # block.register_full_backward_hook(_post_backward_hook_for_block)

         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
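
All of this hangs off PyTorch's full backward pre-hooks (available since PyTorch 2.0), which fire just as autograd enters a module's backward; that is the window in which the handler launches the next prefetch. A minimal, generic illustration of the ordering, using plain PyTorch and nothing from the classes above:

import torch
import torch.nn as nn

block = nn.Linear(4, 4)

def pre_backward(module, grad_output):
    # Runs before gradients w.r.t. the module's inputs are computed,
    # i.e. the point where the handler above kicks off a prefetch.
    print("entering backward of", type(module).__name__)

block.register_full_backward_pre_hook(pre_backward)
loss = block(torch.randn(2, 4)).sum()
loss.backward()  # the hook prints before any input grads exist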