diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index b57c8f0..3c58bd8 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -32,6 +32,7 @@ MODEL_TYPE = "INTERNLM"
 
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()
+
 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.
@@ -104,10 +105,10 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
-        ## TODO: replace num_experts and epsize with function parameter
+        # TODO: replace num_experts and epsize with function parameter
         self.num_experts = num_experts
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
-        if num_experts <= 1: # dense, not MoE
+        if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
                 self.mlp = FeedForward(
                     hidden_size,
@@ -134,27 +135,29 @@ class PackedFlashBaseLayer1D(nn.Module):
                     dtype=dtype,
                 )
         else:
-            expert = torch.nn.ModuleList([FeedForward(
-                hidden_size,
-                int(hidden_size * gpc.config.model.mlp_ratio),
-                out_features=hidden_size,
-                process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias=False,
-                device=torch.device("cuda"),
-                dtype=torch.float,
-            ) for i in range(num_experts // ep_size)])
-            # TODO: test moe for now, need more parameter such as: capacity_factor, eval_capacity_factor, min_capacity, drop_tokens
-            self.mlp = MoE(hidden_size=hidden_size,
-                           expert=expert,
-                           ep_size=ep_size,
-                           num_experts=num_experts,
-                           k=1)
+            expert = torch.nn.ModuleList(
+                [
+                    FeedForward(
+                        hidden_size,
+                        int(hidden_size * gpc.config.model.mlp_ratio),
+                        out_features=hidden_size,
+                        process_group=gpc.get_group(ParallelMode.TENSOR),
+                        bias=False,
+                        device=torch.device("cuda"),
+                        dtype=torch.float,
+                    )
+                    for i in range(num_experts // ep_size)
+                ]
+            )
+            # TODO: test moe for now, need more parameter such as: capacity_factor,
+            # eval_capacity_factor, min_capacity, drop_tokens
+            self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
         self.residual_in_fp32 = residual_in_fp32  # only make sense when using prenorm
         self.return_residual = False
-        self.reset_parameters()  ## TODO: check this should be changed when moe is added
+        self.reset_parameters()  # TODO: check this should be changed when moe is added
 
     def reset_parameters(self):
         with torch.no_grad():
@@ -186,7 +189,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         if self.checkpoint and self.training:
             return activation_checkpoint(
                 self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
-            )  ##TODO: check whether this will be affected by moe
+            )  # TODO: check whether this will be affected by moe
         else:
             return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
 
@@ -550,6 +553,7 @@ def build_model_with_cfg(
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        sequence_parallel=sequence_parallel,
        num_experts=num_experts,
     )
 
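
For context on the MoE branch of this diff: each expert-parallel rank only instantiates its local share of experts (`num_experts // ep_size`), and the `MoE` wrapper routes every token to a single expert (`k=1`). The sketch below is a minimal, single-process illustration of that top-1 routing; `ToyExpert` and `ToyTop1MoE` are hypothetical stand-ins rather than the `FeedForward`/`MoE` classes used in the diff, and expert parallelism, capacity factors, and token dropping (the knobs named in the TODO) are deliberately left out.

```python
import torch
from torch import nn


class ToyExpert(nn.Module):
    """Stand-in for the per-expert FeedForward block."""

    def __init__(self, hidden_size: int, ffn_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, ffn_size, bias=False),
            nn.GELU(),
            nn.Linear(ffn_size, hidden_size, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ToyTop1MoE(nn.Module):
    """Toy top-1 (k=1) MoE: a learned gate picks one expert per token.

    Unlike the MoE layer in the diff, every "rank" here holds all experts
    and no capacity limit or token dropping is applied.
    """

    def __init__(self, hidden_size: int, experts: nn.ModuleList):
        super().__init__()
        self.gate = nn.Linear(hidden_size, len(experts), bias=False)
        self.experts = experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (tokens, hidden_size); score experts and keep the best one per token.
        scores = torch.softmax(self.gate(x), dim=-1)
        top1_score, top1_idx = scores.max(dim=-1)
        out = torch.zeros_like(x)
        for eid, expert in enumerate(self.experts):
            mask = top1_idx == eid
            if mask.any():
                # Scale by the gate probability so routing stays differentiable.
                out[mask] = expert(x[mask]) * top1_score[mask].unsqueeze(-1)
        return out


if __name__ == "__main__":
    hidden, ffn, num_experts, ep_size = 16, 64, 8, 4

    # Mirrors the diff: a rank builds only num_experts // ep_size local experts.
    local_experts = nn.ModuleList([ToyExpert(hidden, ffn) for _ in range(num_experts // ep_size)])
    print(f"{len(local_experts)} local experts per expert-parallel rank")

    # Single-process demo with all experts in one place.
    moe = ToyTop1MoE(hidden, nn.ModuleList([ToyExpert(hidden, ffn) for _ in range(num_experts)]))
    tokens = torch.randn(10, hidden)
    print(moe(tokens).shape)  # torch.Size([10, 16])
```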