mirror of https://github.com/InternLM/InternLM

remove full weight for block 0

parent 5c38cb6409
commit 5abe519c4c
@@ -12,6 +12,7 @@ from torch import nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.embedding import Embedding1D
 from internlm.model.utils import (
     Silu,
     all_gather_raw,
@@ -255,56 +256,33 @@ class FSTPFeedForward(nn.Module):
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
-        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
-            self.w1 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = nn.Linear(
-                hidden_features,
-                out_features,
-                bias=bias,
-                device=device,
-                dtype=dtype,
-            )
-        else:
-            self.w1 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = FSTPLinear(
-                hidden_features,
-                out_features,
-                process_group,
-                bias=bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
+        self.w1 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w2 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w3 = FSTPLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
 
     def forward(self, x):
         w1_o = self.w1(x)
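Note: for context, the three projections built above form the gated-SiLU MLP that FSTPFeedForward's forward wires up (w1_o = self.w1(x) is its first line). A minimal single-device sketch of that data flow, using plain nn.Linear as a stand-in for FSTPLinear (which, per the diff, only adds the process_group / sequence_parallel sharding arguments) and assuming Silu(a, b) in internlm.model.utils means silu(a) * b:

import torch
import torch.nn.functional as F
from torch import nn


class GatedFeedForwardSketch(nn.Module):
    """Minimal stand-in for FSTPFeedForward: w1/w2 project to the hidden size,
    a SiLU gate combines them, and w3 projects back to the output size."""

    def __init__(self, in_features, hidden_features, out_features, multiple_of=256, bias=False):
        super().__init__()
        # round the hidden size up to a multiple, as in the constructor above
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        # assumption: Silu(a, b) == F.silu(a) * b
        return self.w3(F.silu(w1_o) * w2_o)


if __name__ == "__main__":
    ffn = GatedFeedForwardSketch(in_features=64, hidden_features=160, out_features=64)
    print(ffn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])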
@@ -458,6 +436,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.head = []
+        self.embedding = []
 
         self.reduce_scatter_handlers = {}
 
@@ -505,6 +484,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     continue
                 elif isinstance(children, ScaleColumnParallelLinear):
                     self.head.append(children)
+                elif isinstance(children, Embedding1D):
+                    self.embedding.append(children)
 
     def _all_gather_block_weight(self, block_index: int):
         block = self.index_to_block[block_index]
@@ -532,7 +513,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight(block_index + 1)
-                # print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
 
         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
             block_index = self.block_to_index[block]
@@ -548,6 +528,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 handles = self.block_handles[block]
                 for handle in handles:
                     handle.wait()
+
+        def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
+            self._all_gather_block_weight(0)
+
 
         def _post_forward_hook_for_block(block: nn.Module, input, output):
             block_index = self.block_to_index[block]
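Note: _pre_forward_hook_for_embedding (despite its name, it is registered further down with register_forward_hook, so it fires right after the embedding's forward) kicks off the weight all-gather for block 0, which is why the first block no longer needs a resident full-weight copy. A toy, single-process sketch of that trigger pattern, with a plain dict standing in for the handler bookkeeping; the names here are illustrative, not the handler's real fields:

import torch
from torch import nn

prefetched = {}  # block index -> "gathered" weights (stand-in for FSTP_global_weights)

embedding = nn.Embedding(100, 32)
blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(4)])


def all_gather_block_weight(block_index):
    # stand-in for launching the real async all_gather_raw on every module of the block
    prefetched[block_index] = [p.detach().clone() for p in blocks[block_index].parameters()]


def pre_forward_hook_for_embedding(module, inputs, output):
    # runs as soon as the embedding has produced its output, before block 0 executes
    all_gather_block_weight(0)


embedding.register_forward_hook(pre_forward_hook_for_embedding)

h = embedding(torch.randint(0, 100, (2, 8)))
assert 0 in prefetched  # block 0's weights were "prefetched" during the embedding forward
for block in blocks:
    h = block(h)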
@@ -557,11 +541,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
 
-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
             block_index = self.module_to_index[module]
-            if block_index != 0:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
+            handler = self.FSTP_global_handle[module]
+            handler.wait()
 
         def _post_forward_hook_for_module(module: nn.Module, input, output):
             if module in self.FSTP_global_weights:
@@ -593,7 +576,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # if block_index == gpc.config.NUM_LAYER - 1:
             # self._all_gather_block_weight(block_index)
             # start the all-gather for next block
-            if block_index - 1 > 0:
+            if block_index - 1 >= 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@@ -613,38 +596,38 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
            block_index = self.module_to_index[module]
            name_index = self.module_name_index[module]
-            if block_index != 0:
-                if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
-                    # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler = self.FSTP_global_handle[module]
-                    weight_handler.wait()
-                    # self.FSTP_global_weights[module] = total_weight
-
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
-                elif name_index == 0:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-
-                    if block_index - 1 > 0:
-                        next_module = self.block_module[block_index - 1][4]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
-                else:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
+            if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
+                # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler = self.FSTP_global_handle[module]
+                weight_handler.wait()
+                # self.FSTP_global_weights[module] = total_weight
+
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            elif name_index == 0:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+
+                if block_index - 1 >= 0:
+                    next_module = self.block_module[block_index - 1][4]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+            else:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
            # if module in self.FSTP_global_handle:
            # handler = self.FSTP_global_handle[module]
            # handler.wait()
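Note: the rewritten _pre_backward_hook_for_module overlaps communication with compute: it waits on the current module's gather handle and immediately launches the asynchronous all-gather for the previous module (name_index - 1, or module 4 of the previous block when crossing a block boundary). A minimal, single-process torch.distributed sketch of that async all-gather / wait pattern, assuming a gloo backend; all_gather_raw in internlm.model.utils is assumed to wrap an equivalent collective:

import os
import torch
import torch.distributed as dist

# single-process "world" so the sketch runs anywhere
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

shard = torch.randn(4, 8)  # this rank's weight shard
gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

# launch the gather asynchronously and keep the handle instead of blocking here
handle = dist.all_gather(gathered, shard, async_op=True)

# ... other backward computation could overlap with the communication here ...

handle.wait()                      # block only right before the full weight is needed
full_weight = torch.cat(gathered)  # (world_size * 4, 8)
print(full_weight.shape)

dist.destroy_process_group()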
@@ -655,6 +638,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]
 
+        for embedding in self.embedding:
+            embedding.register_forward_hook(_pre_forward_hook_for_embedding)
+
         for head in self.head:
             head.register_full_backward_hook(_post_backward_hook_for_head)
 
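Note: the new registration loop hooks the embedding with register_forward_hook, next to the existing register_full_backward_hook on the head, which starts prefetching as soon as gradients first reach the output projection. A tiny runnable sketch of that backward-hook trigger, with an events list standing in for the real prefetch call:

import torch
from torch import nn

events = []
head = nn.Linear(16, 4)


def post_backward_hook_for_head(module, grad_input, grad_output):
    # in the handler above, this is where the all-gather for the last block would start
    events.append("head backward reached -> prefetch last block")


head.register_full_backward_hook(post_backward_hook_for_head)

out = head(torch.randn(2, 16))
out.sum().backward()
print(events)  # ['head backward reached -> prefetch last block']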
@@ -205,23 +205,14 @@ class MHA(nn.Module):
 
         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            Wqkv_cls = nn.Linear
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )  # according to https://spaces.ac.cn/archives/9577
+        self.Wqkv = Wqkv_cls(
+            embed_dim,
+            3 * embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )  # according to https://spaces.ac.cn/archives/9577
 
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
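Note: with the block-0 special case removed, both tensor-parallel modes build one fused projection of width 3 * embed_dim and leave sharding to the chosen class (ColumnParallelLinearTorch or FSTPLinear). A plain-PyTorch sketch of the fused-QKV idea with nn.Linear standing in for those classes; the head-splitting step is illustrative and not taken from this file:

import torch
from torch import nn


class FusedQKVSketch(nn.Module):
    """One matmul produces queries, keys and values; the parallel linear classes
    in the diff keep this interface and add weight sharding on top."""

    def __init__(self, embed_dim, num_heads, bias=False):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

    def forward(self, x):
        batch, seqlen, _ = x.shape
        qkv = self.Wqkv(x)  # (batch, seqlen, 3 * embed_dim)
        qkv = qkv.view(batch, seqlen, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        return q, k, v


if __name__ == "__main__":
    q, k, v = FusedQKVSketch(embed_dim=64, num_heads=8)(torch.randn(2, 16, 64))
    print(q.shape, k.shape, v.shape)  # each (2, 16, 8, 8)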
@@ -235,23 +226,14 @@ class MHA(nn.Module):
 
         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            out_proj_cls = nn.Linear
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )
+        self.out_proj = out_proj_cls(
+            embed_dim,
+            embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )
 
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]: