mirror of https://github.com/InternLM/InternLM

commit 5abe519c4c ("remove full weight for block 0")
parent 5c38cb6409
@@ -12,6 +12,7 @@ from torch import nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.embedding import Embedding1D
 from internlm.model.utils import (
     Silu,
     all_gather_raw,
@@ -255,56 +256,33 @@ class FSTPFeedForward(nn.Module):
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
-        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
-            self.w1 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = nn.Linear(
-                hidden_features,
-                out_features,
-                bias=bias,
-                device=device,
-                dtype=dtype,
-            )
-        else:
-            self.w1 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = FSTPLinear(
-                hidden_features,
-                out_features,
-                process_group,
-                bias=bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
+        self.w1 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w2 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w3 = FSTPLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
 
     def forward(self, x):
         w1_o = self.w1(x)
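For readers skimming the hunk above: the commit only changes how the three projections are built (always sharded FSTPLinear, no per-block nn.Linear special case); the forward path shown begins like the usual SiLU-gated MLP. Assuming the standard w3(silu(w1(x)) * w2(x)) formulation that the Silu import suggests, a minimal single-device sketch with plain nn.Linear layers — an illustration of the math only, not the repository's sharded implementation, and the sizes are made up:

import torch
import torch.nn.functional as F
from torch import nn


class GatedSiluFeedForward(nn.Module):
    """Plain nn.Linear stand-in for the w1/w2/w3 structure shown in the diff."""

    def __init__(self, in_features, hidden_features, out_features, bias=True):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)   # gate branch
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)   # value branch
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # output projection

    def forward(self, x):
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        # SiLU-gated combination, then project back to the model dimension.
        return self.w3(F.silu(w1_o) * w2_o)


x = torch.randn(2, 8, 16)
print(GatedSiluFeedForward(16, 64, 16)(x).shape)  # torch.Size([2, 8, 16])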
@@ -458,6 +436,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.head = []
+        self.embedding = []
 
         self.reduce_scatter_handlers = {}
 
@@ -505,6 +484,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     continue
                 elif isinstance(children, ScaleColumnParallelLinear):
                     self.head.append(children)
+                elif isinstance(children, Embedding1D):
+                    self.embedding.append(children)
 
     def _all_gather_block_weight(self, block_index: int):
         block = self.index_to_block[block_index]
@@ -532,7 +513,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight(block_index + 1)
-                # print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
 
         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
             block_index = self.block_to_index[block]
@@ -548,6 +528,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 handles = self.block_handles[block]
                 for handle in handles:
                     handle.wait()
+
+        def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
+            self._all_gather_block_weight(0)
+
 
         def _post_forward_hook_for_block(block: nn.Module, input, output):
            block_index = self.block_to_index[block]
@@ -557,11 +541,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
 
-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
             block_index = self.module_to_index[module]
-            if block_index != 0:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
+            handler = self.FSTP_global_handle[module]
+            handler.wait()
 
         def _post_forward_hook_for_module(module: nn.Module, input, output):
             if module in self.FSTP_global_weights:
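The forward-path hunks above all serve one pattern: a pre-forward hook waits on the asynchronous all-gather that re-materializes a block's sharded weights, while earlier hooks launch the all-gather for the next block so communication overlaps with compute; with the block-0 full-weight path removed, block 0 now waits like every other block. Below is a toy sketch of that pattern with torch.distributed, assuming an initialized default process group (e.g. under torchrun); the class, method, and tensor names are illustrative and are not the repository's handler:

import torch
import torch.distributed as dist


class PrefetchAllGatherHandler:
    """Toy version of the all-gather prefetch idea: gather block i+1 while block i runs."""

    def __init__(self, blocks, shards):
        self.blocks = list(blocks)   # transformer blocks, in execution order
        self.shards = shards         # block index -> local weight shard tensor
        self.handles = {}            # block index -> (async work handle, gathered buffer)
        for idx, block in enumerate(self.blocks):
            block.register_forward_pre_hook(self._make_pre_hook(idx))

    def prime(self):
        # Launch the gather for block 0 before the first forward pass; the diff drives
        # the equivalent step from a forward hook on the embedding layer.
        self._launch(0)

    def _launch(self, idx):
        shard = self.shards[idx]
        out = torch.empty(dist.get_world_size() * shard.numel(),
                          dtype=shard.dtype, device=shard.device)
        work = dist.all_gather_into_tensor(out, shard, async_op=True)
        self.handles[idx] = (work, out)

    def _make_pre_hook(self, idx):
        def hook(module, inputs):
            # Wait for this block's gathered weights to be ready...
            if idx in self.handles:
                self.handles[idx][0].wait()
            # ...and immediately start gathering the next block's weights.
            if idx + 1 < len(self.blocks):
                self._launch(idx + 1)
            return None
        return hook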
@@ -593,7 +576,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # if block_index == gpc.config.NUM_LAYER - 1:
             #     self._all_gather_block_weight(block_index)
             # start the all-gather for next block
-            if block_index - 1 > 0:
+            if block_index - 1 >= 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@@ -613,38 +596,38 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
-            if block_index != 0:
-                if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
-                    # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler = self.FSTP_global_handle[module]
-                    weight_handler.wait()
-                    # self.FSTP_global_weights[module] = total_weight
+
+            if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
+                # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler = self.FSTP_global_handle[module]
+                weight_handler.wait()
+                # self.FSTP_global_weights[module] = total_weight
 
-                    # start the all-gather for next module
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            elif name_index == 0:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+
+                if block_index - 1 >= 0:
+                    next_module = self.block_module[block_index - 1][4]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+            else:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+                if name_index != 0:
                     next_module = self.block_module[block_index][name_index - 1]
                     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
                         next_module.weight, self.process_group, async_op=True
                     )
                     self.FSTP_global_handle[next_module] = weights_handler
-                elif name_index == 0:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-
-                    if block_index - 1 > 0:
-                        next_module = self.block_module[block_index - 1][4]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
-                else:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
             # if module in self.FSTP_global_handle:
             #     handler = self.FSTP_global_handle[module]
             #     handler.wait()
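The backward hook above walks each block's five FSTP modules in reverse name order (name_index 4 down to 0) and prefetches the weights needed one step later: index 4 of the last block is gathered up front, each module prefetches index - 1 within its own block, and index 0 hops to index 4 of block - 1. With the block-0 special case gone, the hop condition becomes block_index - 1 >= 0. A small standalone sketch of just that ordering decision (no communication), using a hypothetical helper name:

NUM_LAYER = 3          # illustrative; the real value comes from gpc.config.NUM_LAYER
MODULES_PER_BLOCK = 5  # name_index 0..4, matching self.block_module[block][name_index]


def prefetch_target(block_index, name_index):
    """Return the (block, module) whose weights should be all-gathered next, or None."""
    if name_index > 0:
        return block_index, name_index - 1       # previous module in the same block
    if block_index - 1 >= 0:                     # was "> 0" while block 0 kept full weights
        return block_index - 1, MODULES_PER_BLOCK - 1
    return None                                  # nothing left to prefetch


# Backward visits modules from the last block/module towards the first.
order = [(b, m) for b in reversed(range(NUM_LAYER)) for m in reversed(range(MODULES_PER_BLOCK))]
for blk, mod in order:
    print((blk, mod), "-> prefetch", prefetch_target(blk, mod))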
@@ -655,6 +638,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]
 
+        for embedding in self.embedding:
+            embedding.register_forward_hook(_pre_forward_hook_for_embedding)
+
         for head in self.head:
             head.register_full_backward_hook(_post_backward_hook_for_head)
 
@@ -205,23 +205,14 @@ class MHA(nn.Module):
 
         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            Wqkv_cls = nn.Linear
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )  # according to https://spaces.ac.cn/archives/9577
+        self.Wqkv = Wqkv_cls(
+            embed_dim,
+            3 * embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )  # according to https://spaces.ac.cn/archives/9577
 
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
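The Wqkv layer above is a single packed projection from embed_dim to 3 * embed_dim; query, key and value are sliced out of one matmul afterwards. A minimal unsharded sketch of that packing with made-up sizes — the real class is ColumnParallelLinearTorch or FSTPLinear, and nothing here reproduces their sharding:

import torch
from torch import nn

embed_dim, num_heads = 32, 4
head_dim = embed_dim // num_heads

wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)  # one fused projection

x = torch.randn(2, 7, embed_dim)                 # (batch, seq, embed_dim)
qkv = wqkv(x)                                    # (batch, seq, 3 * embed_dim)
qkv = qkv.view(2, 7, 3, num_heads, head_dim)     # unpack the fused dimension
q, k, v = qkv.unbind(dim=2)                      # each: (batch, seq, num_heads, head_dim)
print(q.shape, k.shape, v.shape)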
@@ -235,23 +226,14 @@ class MHA(nn.Module):
 
         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            out_proj_cls = nn.Linear
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )
+        self.out_proj = out_proj_cls(
+            embed_dim,
+            embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]: