remove full weight for block 0

pull/407/head
yingtongxiong 2023-10-17 16:37:06 +08:00
parent 5c38cb6409
commit 5abe519c4c
2 changed files with 84 additions and 116 deletions

View File

@ -12,6 +12,7 @@ from torch import nn
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.utils import ( from internlm.model.utils import (
Silu, Silu,
all_gather_raw, all_gather_raw,
@ -255,29 +256,6 @@ class FSTPFeedForward(nn.Module):
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
self.w1 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w2 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w3 = nn.Linear(
hidden_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
else:
self.w1 = FSTPLinear( self.w1 = FSTPLinear(
in_features, in_features,
hidden_features, hidden_features,
@ -458,6 +436,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = [] self.head = []
self.embedding = []
self.reduce_scatter_handlers = {} self.reduce_scatter_handlers = {}
@ -505,6 +484,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
continue continue
elif isinstance(children, ScaleColumnParallelLinear): elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children) self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)
def _all_gather_block_weight(self, block_index: int): def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index] block = self.index_to_block[block_index]
@ -532,7 +513,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# start the all-gather for next block # start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER: if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight(block_index + 1) self._all_gather_block_weight(block_index + 1)
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block] block_index = self.block_to_index[block]
@ -549,6 +529,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for handle in handles: for handle in handles:
handle.wait() handle.wait()
def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
self._all_gather_block_weight(0)
def _post_forward_hook_for_block(block: nn.Module, input, output): def _post_forward_hook_for_block(block: nn.Module, input, output):
block_index = self.block_to_index[block] block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index] fsdp_modules = self.index_to_fsdp_modules[block_index]
@ -557,9 +541,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for module in fsdp_modules: for module in fsdp_modules:
del self.FSTP_global_weights[module] del self.FSTP_global_weights[module]
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
block_index = self.module_to_index[module] block_index = self.module_to_index[module]
if block_index != 0:
handler = self.FSTP_global_handle[module] handler = self.FSTP_global_handle[module]
handler.wait() handler.wait()
@ -593,7 +576,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# if block_index == gpc.config.NUM_LAYER - 1: # if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index) # self._all_gather_block_weight(block_index)
# start the all-gather for next block # start the all-gather for next block
if block_index - 1 > 0: if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1) self._all_gather_block_weight(block_index - 1)
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@ -613,7 +596,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
def _pre_backward_hook_for_module(module: nn.Module, grad_output): def _pre_backward_hook_for_module(module: nn.Module, grad_output):
block_index = self.module_to_index[module] block_index = self.module_to_index[module]
name_index = self.module_name_index[module] name_index = self.module_name_index[module]
if block_index != 0:
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module] weight_handler = self.FSTP_global_handle[module]
@ -630,7 +613,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
handler = self.FSTP_global_handle[module] handler = self.FSTP_global_handle[module]
handler.wait() handler.wait()
if block_index - 1 > 0: if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4] next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True next_module.weight, self.process_group, async_op=True
@ -655,6 +638,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if module in self.FSTP_global_handle: if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module] del self.FSTP_global_handle[module]
for embedding in self.embedding:
embedding.register_forward_hook(_pre_forward_hook_for_embedding)
for head in self.head: for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head) head.register_full_backward_hook(_post_backward_hook_for_head)

View File

@ -205,15 +205,6 @@ class MHA(nn.Module):
# notice here should change bias=True # notice here should change bias=True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
Wqkv_cls = nn.Linear
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.Wqkv = Wqkv_cls( self.Wqkv = Wqkv_cls(
embed_dim, embed_dim,
3 * embed_dim, 3 * embed_dim,
@ -235,15 +226,6 @@ class MHA(nn.Module):
# output projection always have the bias (for now) # output projection always have the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
out_proj_cls = nn.Linear
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.out_proj = out_proj_cls( self.out_proj = out_proj_cls(
embed_dim, embed_dim,
embed_dim, embed_dim,