remove full weight for block 0

pull/407/head
yingtongxiong 2023-10-17 16:37:06 +08:00
parent 5c38cb6409
commit 5abe519c4c
2 changed files with 84 additions and 116 deletions

View File

@ -12,6 +12,7 @@ from torch import nn
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.utils import (
Silu,
all_gather_raw,
@ -255,56 +256,33 @@ class FSTPFeedForward(nn.Module):
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
self.w1 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w2 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w3 = nn.Linear(
hidden_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
else:
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
def forward(self, x):
w1_o = self.w1(x)
@ -458,6 +436,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = []
self.embedding = []
self.reduce_scatter_handlers = {}
@ -505,6 +484,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)
def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index]
@ -532,7 +513,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight(block_index + 1)
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block]
@ -548,6 +528,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
handles = self.block_handles[block]
for handle in handles:
handle.wait()
def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
self._all_gather_block_weight(0)
def _post_forward_hook_for_block(block: nn.Module, input, output):
block_index = self.block_to_index[block]
@ -557,11 +541,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for module in fsdp_modules:
del self.FSTP_global_weights[module]
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
block_index = self.module_to_index[module]
if block_index != 0:
handler = self.FSTP_global_handle[module]
handler.wait()
handler = self.FSTP_global_handle[module]
handler.wait()
def _post_forward_hook_for_module(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
@ -593,7 +576,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index)
# start the all-gather for next block
if block_index - 1 > 0:
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@ -613,38 +596,38 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if block_index != 0:
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 > 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
@ -655,6 +638,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
for embedding in self.embedding:
embedding.register_forward_hook(_pre_forward_hook_for_embedding)
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)

View File

@ -205,23 +205,14 @@ class MHA(nn.Module):
# notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
Wqkv_cls = nn.Linear
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@ -235,23 +226,14 @@ class MHA(nn.Module):
# output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
out_proj_cls = nn.Linear
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign tp attribute so that internlm know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name in ["out_proj", "Wqkv"]: