diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 71bdf05..cc9524a 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -12,6 +12,7 @@ from torch import nn
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.embedding import Embedding1D
 from internlm.model.utils import (
     Silu,
     all_gather_raw,
@@ -255,56 +256,33 @@ class FSTPFeedForward(nn.Module):
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
-        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
-            self.w1 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = nn.Linear(
-                in_features,
-                hidden_features,
-                bias,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = nn.Linear(
-                hidden_features,
-                out_features,
-                bias=bias,
-                device=device,
-                dtype=dtype,
-            )
-        else:
-            self.w1 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w2 = FSTPLinear(
-                in_features,
-                hidden_features,
-                process_group,
-                bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
-            self.w3 = FSTPLinear(
-                hidden_features,
-                out_features,
-                process_group,
-                bias=bias,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                device=device,
-                dtype=dtype,
-            )
+        self.w1 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w2 = FSTPLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
+        self.w3 = FSTPLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            device=device,
+            dtype=dtype,
+        )
 
     def forward(self, x):
         w1_o = self.w1(x)
@@ -458,6 +436,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.head = []
+        self.embedding = []
 
         self.reduce_scatter_handlers = {}
 
@@ -505,6 +484,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                                 continue
                 elif isinstance(children, ScaleColumnParallelLinear):
                     self.head.append(children)
+                elif isinstance(children, Embedding1D):
+                    self.embedding.append(children)
 
     def _all_gather_block_weight(self, block_index: int):
         block = self.index_to_block[block_index]
@@ -532,7 +513,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight(block_index + 1)
-                # print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
 
         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
             block_index = self.block_to_index[block]
@@ -548,6 +528,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 handles = self.block_handles[block]
                 for handle in handles:
                     handle.wait()
+
+        def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
+            self._all_gather_block_weight(0)
+
 
         def _post_forward_hook_for_block(block: nn.Module, input, output):
             block_index = self.block_to_index[block]
@@ -557,11 +541,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
 
-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
             block_index = self.module_to_index[module]
-            if block_index != 0:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
+            handler = self.FSTP_global_handle[module]
+            handler.wait()
 
         def _post_forward_hook_for_module(module: nn.Module, input, output):
             if module in self.FSTP_global_weights:
@@ -593,7 +576,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # if block_index == gpc.config.NUM_LAYER - 1:
             #     self._all_gather_block_weight(block_index)
             # start the all-gather for next block
-            if block_index - 1 > 0:
+            if block_index - 1 >= 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@@ -613,38 +596,38 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
            block_index = self.module_to_index[module]
            name_index = self.module_name_index[module]
-            if block_index != 0:
-                if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
-                    # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler = self.FSTP_global_handle[module]
-                    weight_handler.wait()
-                    # self.FSTP_global_weights[module] = total_weight
+
+            if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
+                # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler = self.FSTP_global_handle[module]
+                weight_handler.wait()
+                # self.FSTP_global_weights[module] = total_weight
 
-                    # start the all-gather for next module
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            elif name_index == 0:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+
+                if block_index - 1 >= 0:
+                    next_module = self.block_module[block_index - 1][4]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+            else:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+                if name_index != 0:
                     next_module = self.block_module[block_index][name_index - 1]
                     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
                         next_module.weight, self.process_group, async_op=True
                     )
                     self.FSTP_global_handle[next_module] = weights_handler
-                elif name_index == 0:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-
-                    if block_index - 1 > 0:
-                        next_module = self.block_module[block_index - 1][4]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
-                else:
-                    handler = self.FSTP_global_handle[module]
-                    handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
             # if module in self.FSTP_global_handle:
             #     handler = self.FSTP_global_handle[module]
             #     handler.wait()
@@ -655,6 +638,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]
 
+        for embedding in self.embedding:
+            embedding.register_forward_hook(_pre_forward_hook_for_embedding)
+
         for head in self.head:
             head.register_full_backward_hook(_post_backward_hook_for_head)
 
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6c1e7d8..7a0f4ed 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -205,23 +205,14 @@ class MHA(nn.Module):
 
         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            Wqkv_cls = nn.Linear
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.Wqkv = Wqkv_cls(
-                embed_dim,
-                3 * embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )  # according to https://spaces.ac.cn/archives/9577
+        self.Wqkv = Wqkv_cls(
+            embed_dim,
+            3 * embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )  # according to https://spaces.ac.cn/archives/9577
 
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -235,23 +226,14 @@ class MHA(nn.Module):
 
         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
-            out_proj_cls = nn.Linear
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                bias=False,
-                **factory_kwargs,
-            )
-        else:
-            self.out_proj = out_proj_cls(
-                embed_dim,
-                embed_dim,
-                process_group,
-                bias=False,
-                sequence_parallel=gpc.config.parallel.sequence_parallel,
-                **factory_kwargs,
-            )
+        self.out_proj = out_proj_cls(
+            embed_dim,
+            embed_dim,
+            process_group,
+            bias=False,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
+        )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
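For reference, the prefetch choreography this patch wires up (embedding forward hook launches the async all-gather for block 0; each FSTP linear's pre-forward hook waits on its own handle) boils down to the following minimal sketch. It is illustrative only, not part of the patch: it assumes an initialized torch.distributed process group, uses plain nn.Linear in place of FSTPLinear, omits the sharded matmul itself, and the name PrefetchHandler and its methods are hypothetical, not InternLM APIs.

import torch
import torch.distributed as dist
from torch import nn


class PrefetchHandler:
    """Sketch of hook-driven weight prefetching for a list of transformer blocks."""

    def __init__(self, blocks: nn.ModuleList, embedding: nn.Module, process_group):
        self.blocks = blocks
        self.process_group = process_group
        self.handles = {}  # module -> (gathered flat weight, async work handle)

        # Kick off the all-gather for block 0 as soon as the embedding has run,
        # mirroring _pre_forward_hook_for_embedding in the patch.
        embedding.register_forward_hook(self._embedding_forward_hook)
        for block in blocks:
            for module in block.modules():
                if isinstance(module, nn.Linear):
                    module.register_forward_pre_hook(self._module_pre_forward_hook)

    def _all_gather_block_weight(self, block_index: int):
        # Launch one asynchronous all-gather per linear in the given block.
        world_size = dist.get_world_size(self.process_group)
        for module in self.blocks[block_index].modules():
            if isinstance(module, nn.Linear):
                flat = module.weight.detach().contiguous().view(-1)
                out = torch.empty(world_size * flat.numel(), dtype=flat.dtype, device=flat.device)
                work = dist.all_gather_into_tensor(out, flat, group=self.process_group, async_op=True)
                self.handles[module] = (out, work)

    def _embedding_forward_hook(self, module, inputs, output):
        # Prefetch block 0's weights right after the embedding forward.
        self._all_gather_block_weight(0)

    def _module_pre_forward_hook(self, module, inputs):
        # Mirrors _pre_forward_hook_for_module: block until this module's
        # gathered weight is ready before its forward runs.
        if module in self.handles:
            _, work = self.handles[module]
            work.wait()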