mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): block-grained backward
parent 0d1fa037dd
commit d1af0d6aee
@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 8
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -161,10 +161,11 @@ pipeline parallel (dict):
    sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=-1, fsdp=False),
+    tensor=dict(size=8, mode="fstp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
+    block_0_full_weight=True,
 )
 
 cudnn_deterministic = False
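Putting the third hunk together, the parallel section reads as sketched below after this commit. The comments are my reading of the fields from this diff and the old inline comment; in particular the meaning attributed to zero1 size -1 and to the new block_0_full_weight flag is an assumption, not documented behavior.

parallel = dict(
    # zero1 size -1: assumed to mean the ZeRO-1 group spans the full
    # data-parallel group, so optimizer states are sharded across all DP ranks.
    zero1=dict(size=-1, fsdp=False),
    # Per the removed inline comment, mode is "origin_tp" or "fstp", and
    # "fstp" requires sequence_parallel=True.
    tensor=dict(size=8, mode="fstp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
    # New flag in this commit: presumably keeps block 0's weights unsharded so
    # the backward prefetch can stop at block 1 (see the "> 0" check below).
    block_0_full_weight=True,
)

NUM_LAYER = 4 and total_steps = 20 shrink the run to a quick debug scale for exercising the new backward path.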
@@ -559,21 +559,21 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             block_index = self.block_to_index[block]
-            if block_index == gpc.config.NUM_LAYER - 1:
-                # all gather weight for the last block
-                fsdp_modules = self.index_to_fsdp_modules[block_index]
-                for module in fsdp_modules:
-                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handle.wait()
-                    self.FSTP_global_weights[module] = total_weight
-            else:
-                # wait handle for current block
-                handles = self.block_handles[block]
-                for handle in handles:
-                    handle.wait()
+            # if block_index == gpc.config.NUM_LAYER - 1:
+            #     # all gather weight for the last block
+            #     fsdp_modules = self.index_to_fsdp_modules[block_index]
+            #     for module in fsdp_modules:
+            #         total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            #         weight_handle.wait()
+            #         self.FSTP_global_weights[module] = total_weight
+            # else:
+            #     # wait handle for current block
+            #     handles = self.block_handles[block]
+            #     for handle in handles:
+            #         handle.wait()
 
             # start the all-gather for next block
-            if block_index - 1 >= 0:
+            if block_index - 1 > 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
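This hunk is the core of the block-grained backward: the eager all-gather and wait logic inside _pre_backward_hook_for_block is commented out, so the hook now only launches the asynchronous prefetch of the weights for the next block to run backward (block_index - 1), and block 0 is skipped (the condition changes from ">= 0" to "> 0", which lines up with the new block_0_full_weight=True). Below is a minimal, standalone sketch of that prefetch pattern built on stock PyTorch collectives; BlockPrefetcher and _gather_block are hypothetical names, the diff's all_gather_raw is assumed to behave like an async all-gather returning a (tensor, handle) pair, and register_full_backward_pre_hook requires PyTorch >= 2.0. It illustrates the technique, not the InternLM implementation.

import torch
import torch.distributed as dist
from torch import nn


class BlockPrefetcher:
    """Prefetch the previous block's sharded weights when backward enters a block."""

    def __init__(self, blocks: nn.ModuleList, process_group):
        self.blocks = blocks
        self.process_group = process_group
        self.block_to_index = {block: i for i, block in enumerate(blocks)}
        self.gathered = {}  # module -> full (unsharded) weight
        self.handles = {}   # module -> async all-gather work handle

    def _gather_block(self, block_index: int) -> None:
        # Launch one async all-gather per sharded linear in the block; the
        # consumer waits on the handle right before the weight is used.
        world_size = dist.get_world_size(self.process_group)
        for module in self.blocks[block_index].modules():
            if isinstance(module, nn.Linear):
                shard = module.weight.data
                full = torch.empty(
                    (shard.shape[0] * world_size, *shard.shape[1:]),
                    dtype=shard.dtype,
                    device=shard.device,
                )
                handle = dist.all_gather_into_tensor(
                    full, shard.contiguous(), group=self.process_group, async_op=True
                )
                self.gathered[module] = full
                self.handles[module] = handle

    def register_hooks(self) -> None:
        def pre_backward_block(block, grad_output):
            # Backward is about to run on this block: prefetch the previous
            # block's weights so communication overlaps with this block's
            # gradient computation. Block 0 is skipped, matching the "> 0"
            # check in the diff (block 0 is assumed to keep full weights).
            idx = self.block_to_index[block]
            if idx - 1 > 0:
                self._gather_block(idx - 1)

        def post_backward_module(module, grad_input, grad_output):
            # Free the gathered copy once this module's backward has used it.
            self.handles.pop(module, None)
            self.gathered.pop(module, None)

        for block in self.blocks:
            block.register_full_backward_pre_hook(pre_backward_block)
            for module in block.modules():
                if isinstance(module, nn.Linear):
                    module.register_full_backward_hook(post_backward_module)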
@@ -588,36 +588,41 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
             if block_index != 0:
-                if name_index == 4:
-                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler.wait()
-                    self.FSTP_global_weights[module] = total_weight
+                # if name_index == 4:
+                #     total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                #     weight_handler.wait()
+                #     self.FSTP_global_weights[module] = total_weight
 
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
-                else:
+                #     # start the all-gather for next module
+                #     next_module = self.block_module[block_index][name_index - 1]
+                #     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #         next_module.weight, self.process_group, async_op=True
+                #     )
+                #     self.FSTP_global_handle[next_module] = weights_handler
+                # else:
+                #     handler = self.FSTP_global_handle[module]
+                #     handler.wait()
+                #     if name_index != 0:
+                #         next_module = self.block_module[block_index][name_index - 1]
+                #         self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #             next_module.weight, self.process_group, async_op=True
+                #         )
+                #         self.FSTP_global_handle[next_module] = weights_handler
+                if module in self.FSTP_global_handle:
                     handler = self.FSTP_global_handle[module]
                     handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
 
         def _post_backward_hook_for_module(module, grad_input, grad_output):
             if module in self.FSTP_global_weights:
                 del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_handle:
+                del self.FSTP_global_handle[module]
 
-        # for block in self.FSTP_blocks:
+        for block in self.FSTP_blocks:
             # block.register_forward_pre_hook(_pre_forward_hook_for_block)
             # block.register_forward_hook(_post_forward_hook_for_block)
-            # block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
             # block.register_full_backward_hook(_post_backward_hook_for_block)
 
         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
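The second hunk keeps the per-module pre-forward path but reduces it to a guarded wait on an already-launched handle, adds handle cleanup to _post_backward_hook_for_module, and switches registration from the per-module fallbacks to the block-level register_full_backward_pre_hook. A hedged sketch of that wait-then-free discipline, reusing the hypothetical BlockPrefetcher from the previous sketch (not an InternLM object):

import torch
from torch import nn


def acquire_full_weight(prefetcher, module: nn.Linear) -> torch.Tensor:
    """Called right before a module needs its full (unsharded) weight."""
    handle = prefetcher.handles.get(module)
    if handle is not None:
        # Mirrors "if module in self.FSTP_global_handle: handler.wait()" in
        # the diff: only wait if a prefetch was actually launched.
        handle.wait()
    # Modules that were never prefetched (e.g. block 0, assumed to keep full
    # weights) fall back to their local parameter.
    return prefetcher.gathered.get(module, module.weight)


def release_full_weight(prefetcher, module: nn.Linear) -> None:
    # Mirrors _post_backward_hook_for_module: drop both the gathered copy and
    # the handle so roughly one block's worth of full weights stays resident.
    prefetcher.gathered.pop(module, None)
    prefetcher.handles.pop(module, None)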