From 0d1fa037ddd3c899e3c42fbb9c013b17c4dd03dc Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 16 Oct 2023 20:13:59 +0800 Subject: [PATCH] feat(model/linear.py): set block 0 full weight --- internlm/model/linear.py | 133 +++++++++++++++---------- internlm/model/modeling_internlm.py | 6 +- internlm/model/multi_head_attention.py | 53 ++++++---- internlm/train/training_internlm.py | 13 ++- 4 files changed, 131 insertions(+), 74 deletions(-) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 890f1cb..8a17c71 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -175,6 +175,7 @@ class FeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, + block_idx: int = 0, ): super().__init__() @@ -248,38 +249,62 @@ class FSTPFeedForward(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, + block_idx: int = 0, ): super().__init__() hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) - self.w1 = FSTPLinear( - in_features, - hidden_features, - process_group, - bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) - self.w2 = FSTPLinear( - in_features, - hidden_features, - process_group, - bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) - self.w3 = FSTPLinear( - hidden_features, - out_features, - process_group, - bias=bias, - sequence_parallel=gpc.config.parallel.sequence_parallel, - device=device, - dtype=dtype, - ) + if block_idx == 0 and gpc.config.parallel.block_0_full_weight: + self.w1 = nn.Linear( + in_features, + hidden_features, + bias, + device=device, + dtype=dtype, + ) + self.w2 = nn.Linear( + in_features, + hidden_features, + bias, + device=device, + dtype=dtype, + ) + self.w3 = nn.Linear( + hidden_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + else: + self.w1 = FSTPLinear( + in_features, + hidden_features, + process_group, + bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + self.w2 = FSTPLinear( + in_features, + hidden_features, + process_group, + bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) + self.w3 = FSTPLinear( + hidden_features, + out_features, + process_group, + bias=bias, + sequence_parallel=gpc.config.parallel.sequence_parallel, + device=device, + dtype=dtype, + ) def forward(self, x): w1_o = self.w1(x) @@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler: # print(f"name: {name}", flush=True) if name == "out_proj": self.FSTP_outs.append(child) - # self.module_to_index[child] = idx + self.module_to_index[child] = idx if name == "Wqkv": self.FSTP_wqkvs.append(child) - # self.module_to_index[child] = idx + self.module_to_index[child] = idx if isinstance(child, FSTPLinear): self.module_to_index[child] = idx self.block_module[idx][index] = child @@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler: # start the all-gather for next block if block_index + 1 < gpc.config.NUM_LAYER: self._all_gather_block_weight(block_index + 1) + # print(f"_all_gather_block_weight for block {block_index+1}", flush=True) def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): block_index = self.block_to_index[block] @@ -512,14 +538,13 @@ class CoarseGrainedFSTPAllGatherSyncHandler: del self.block_handles[block] for module in 
fsdp_modules: del self.FSTP_global_weights[module] - - + def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): block_index = self.module_to_index[module] if block_index != 0: handler = self.FSTP_global_handle[module] handler.wait() - + def _post_forward_hook_for_module(module: nn.Module, input, output): if module in self.FSTP_global_weights: del self.FSTP_global_weights[module] @@ -558,46 +583,48 @@ class CoarseGrainedFSTPAllGatherSyncHandler: del self.block_handles[block] for module in fsdp_modules: del self.FSTP_global_weights[module] - + def _pre_backward_hook_for_module(module: nn.Module, grad_output): block_index = self.module_to_index[module] name_index = self.module_name_index[module] - if name_index == 4: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - weight_handler.wait() - self.FSTP_global_weights[module] = total_weight + if block_index != 0: + if name_index == 4: + total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + weight_handler.wait() + self.FSTP_global_weights[module] = total_weight - # start the all-gather for next module - next_module = self.block_module[block_index][name_index - 1] - self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - next_module.weight, self.process_group, async_op=True - ) - self.FSTP_global_handle[next_module] = weights_handler - else: - handler = self.FSTP_global_handle[module] - handler.wait() - if name_index != 0: + # start the all-gather for next module next_module = self.block_module[block_index][name_index - 1] self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( next_module.weight, self.process_group, async_op=True ) self.FSTP_global_handle[next_module] = weights_handler + else: + handler = self.FSTP_global_handle[module] + handler.wait() + if name_index != 0: + next_module = self.block_module[block_index][name_index - 1] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( + next_module.weight, self.process_group, async_op=True + ) + self.FSTP_global_handle[next_module] = weights_handler def _post_backward_hook_for_module(module, grad_input, grad_output): - del self.FSTP_global_weights[module] + if module in self.FSTP_global_weights: + del self.FSTP_global_weights[module] # for block in self.FSTP_blocks: - # block.register_forward_pre_hook(_pre_forward_hook_for_block) - # block.register_forward_hook(_post_forward_hook_for_block) - # block.register_full_backward_pre_hook(_pre_backward_hook_for_block) - # block.register_full_backward_hook(_post_backward_hook_for_block) + # block.register_forward_pre_hook(_pre_forward_hook_for_block) + # block.register_forward_hook(_post_forward_hook_for_block) + # block.register_full_backward_pre_hook(_pre_backward_hook_for_block) + # block.register_full_backward_hook(_post_backward_hook_for_block) for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) # for wqkv in self.FSTP_wqkvs: # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv) - + for module in self.FSTP_modules: module.register_forward_pre_hook(_pre_forward_hook_for_module) module.register_forward_hook(_post_forward_hook_for_module) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 228dbd3..cb93396 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_swiglu: bool = True, 
use_flash_attn: bool = True, tp_mode: str = "origin_tp", + block_idx: int = 0, ): super().__init__() self.checkpoint = checkpoint @@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module): device=device, dtype=dtype, tp_mode=tp_mode, + block_idx=block_idx, ) self.dropout1 = nn.Dropout(drop_rate) @@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module): bias=False, device=device, dtype=dtype, + block_idx=block_idx, ) else: self.mlp = ParallelFusedMLP( @@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module): use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, tp_mode=self.tp_mode, + block_idx=lid, ) for lid in range(num_layers) ] @@ -410,7 +414,7 @@ class PackedFlashInternLm1D(nn.Module): # Evaluation if hidden_states.ndim == 3: hidden_states = self.head(hidden_states, gather_dim=1) - else: # Training + else: # Training hidden_states = self.head(hidden_states, gather_dim=0) if not self.parallel_output: diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 1db98d7..6c1e7d8 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function): @staticmethod def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: - ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx @@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module): second_scatter_idx: int = 0, second_gather_idx: int = 1, ) -> None: - super().__init__() self.local_attn = local_attention self.spg = sequence_process_group @@ -178,6 +176,7 @@ class MHA(nn.Module): device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, tp_mode: str = "origin_tp", + block_idx: int = 0, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -206,14 +205,23 @@ class MHA(nn.Module): # notice here should change bias=True Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear - self.Wqkv = Wqkv_cls( - embed_dim, - 3 * embed_dim, - process_group, - bias=False, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) # according to https://spaces.ac.cn/archives/9577 + if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight: + Wqkv_cls = nn.Linear + self.Wqkv = Wqkv_cls( + embed_dim, + 3 * embed_dim, + bias=False, + **factory_kwargs, + ) + else: + self.Wqkv = Wqkv_cls( + embed_dim, + 3 * embed_dim, + process_group, + bias=False, + sequence_parallel=gpc.config.parallel.sequence_parallel, + **factory_kwargs, + ) # according to https://spaces.ac.cn/archives/9577 inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention @@ -227,14 +235,23 @@ class MHA(nn.Module): # output projection always have the bias (for now) out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear - self.out_proj = out_proj_cls( - embed_dim, - embed_dim, - process_group, - bias=False, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) + if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight: + out_proj_cls = nn.Linear + self.out_proj = out_proj_cls( + embed_dim, + embed_dim, + bias=False, + **factory_kwargs, + ) + else: + self.out_proj = out_proj_cls( + embed_dim, + embed_dim, + process_group, + bias=False, + 
sequence_parallel=gpc.config.parallel.sequence_parallel, + **factory_kwargs, + ) # need to assign tp attribute so that internlm know it is tensor parallel module if gpc.get_world_size(ParallelMode.TENSOR) > 1: for name in ["out_proj", "Wqkv"]: diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 572adba..24040a0 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -110,8 +110,8 @@ def initialize_model(): model = wrap_FSDP_model(model) if gpc.config.parallel["tensor"]["mode"] == "fstp": - # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) - handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) + handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) + # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) handler._register_sync_parameters_hook() gpc.config.fstp_handler = handler return model @@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None): ) +tgs_list = [] + + @llm_timeout(func_name="record_current_batch_training_metrics") def record_current_batch_training_metrics( get_tflops_func, @@ -568,3 +571,9 @@ def record_current_batch_training_metrics( step_count=batch_count, cur_step_loss=loss.item(), ) + + if batch_count >= 5: + tgs_list.append(tgs_origin) + if batch_count == gpc.config.data.total_steps - 1: + print(tgs_list, flush=True) + print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
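
Note: the new code paths read gpc.config.parallel.block_0_full_weight (and initialize_model
now builds a CoarseGrainedFSTPAllGatherSyncHandler when parallel["tensor"]["mode"] == "fstp"),
but the patch does not touch any config file. Below is a minimal sketch of what the parallel
section of a training config might look like with the new flag enabled. It assumes the usual
InternLM dict-style config; every field name other than block_0_full_weight, sequence_parallel,
and tensor["mode"] is illustrative rather than taken from this patch.

    # configs/<your_config>.py -- hypothetical excerpt, not part of this patch
    parallel = dict(
        zero1=8,
        # "fstp" routes MHA/MLP linears through FSTPLinear and enables the
        # all-gather sync handler registered in initialize_model()
        tensor=dict(size=8, mode="fstp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
        # new flag read by FSTPFeedForward and MHA: block 0 keeps full
        # (unsharded) nn.Linear weights, so no all-gather is issued for its
        # Wqkv / out_proj / w1 / w2 / w3 modules
        block_0_full_weight=True,
    )

Since block 0 evaluates gpc.config.parallel.block_0_full_weight directly, the flag presumably
has to be present in the config whenever tensor mode is "fstp"; otherwise the attribute lookup
in linear.py and multi_head_attention.py would fail at model construction time.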