modified: internlm/model/modeling_internlm.py

pull/375/head
Wenwen Qu 2023-08-08 15:47:46 +08:00
parent 8b198b2665
commit 2a52452ed2
1 changed file with 23 additions and 19 deletions

@@ -32,6 +32,7 @@ MODEL_TYPE = "INTERNLM"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()

 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.
@ -104,10 +105,10 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
## TODO: replace num_experts and epsize with function parameter # TODO: replace num_experts and epsize with function parameter
self.num_experts = num_experts self.num_experts = num_experts
ep_size = gpc.get_world_size(ParallelMode.EXPERT) ep_size = gpc.get_world_size(ParallelMode.EXPERT)
if num_experts <= 1: # dense, not MoE if num_experts <= 1: # dense, not MoE
if use_swiglu: if use_swiglu:
self.mlp = FeedForward( self.mlp = FeedForward(
hidden_size, hidden_size,
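Note on the dispatch above: with num_experts <= 1 the layer keeps a single dense MLP, and only in the MoE branch (next hunk) does each rank of the ParallelMode.EXPERT group build its own shard of experts, num_experts // ep_size of them. A rough sketch of that per-rank arithmetic, with hypothetical example values not taken from this diff:

# Sketch: how many expert modules each expert-parallel rank builds.
num_experts = 8                  # total experts in the model (example value)
ep_size = 4                      # gpc.get_world_size(ParallelMode.EXPERT) (example value)
assert num_experts % ep_size == 0, "experts should divide evenly across expert-parallel ranks"
num_local_experts = num_experts // ep_size
print(num_local_experts)         # -> 2 FeedForward experts instantiated per rank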
@ -134,27 +135,29 @@ class PackedFlashBaseLayer1D(nn.Module):
dtype=dtype, dtype=dtype,
) )
else: else:
expert = torch.nn.ModuleList([FeedForward( expert = torch.nn.ModuleList(
hidden_size, [
int(hidden_size * gpc.config.model.mlp_ratio), FeedForward(
out_features=hidden_size, hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR), int(hidden_size * gpc.config.model.mlp_ratio),
bias=False, out_features=hidden_size,
device=torch.device("cuda"), process_group=gpc.get_group(ParallelMode.TENSOR),
dtype=torch.float, bias=False,
) for i in range(num_experts // ep_size)]) device=torch.device("cuda"),
# TODO: test moe for now, need more parameter such as: capacity_factor, eval_capacity_factor, min_capacity, drop_tokens dtype=torch.float,
self.mlp = MoE(hidden_size=hidden_size, )
expert=expert, for i in range(num_experts // ep_size)
ep_size=ep_size, ]
num_experts=num_experts, )
k=1) # TODO: test moe for now, need more parameter such as: capacity_factor,
# eval_capacity_factor, min_capacity, drop_tokens
self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
self.dropout2 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init self.use_scaled_init = use_scaled_init
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
self.return_residual = False self.return_residual = False
self.reset_parameters() ## TODO: check this should be changed when moe is added self.reset_parameters() # TODO: check this should be changed when moe is added
def reset_parameters(self): def reset_parameters(self):
with torch.no_grad(): with torch.no_grad():
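For context, the MoE layer constructed above follows the usual pattern of a gating network picking one expert per token (k=1) from a ModuleList of experts. Below is a minimal single-process sketch of that top-1 routing; SimpleTop1MoE is hypothetical and deliberately omits expert parallelism, capacity limits, token dropping, and the auxiliary load-balancing loss that the real MoE class in this diff would add:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTop1MoE(nn.Module):
    def __init__(self, hidden_size, num_experts, mlp_ratio=4):
        super().__init__()
        # the gate scores every token against every expert
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size * mlp_ratio),
                nn.GELU(),
                nn.Linear(hidden_size * mlp_ratio, hidden_size),
            )
            for _ in range(num_experts)
        )

    def forward(self, x):  # x: (num_tokens, hidden_size), packed like the layer above
        probs = F.softmax(self.gate(x), dim=-1)   # (num_tokens, num_experts)
        top_prob, top_idx = probs.max(dim=-1)     # k=1: one expert per token
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            mask = top_idx == e
            if mask.any():
                # each token's output is scaled by its gate probability
                out[mask] = expert(x[mask]) * top_prob[mask].unsqueeze(-1)
        return out

# usage sketch
moe = SimpleTop1MoE(hidden_size=64, num_experts=4)
y = moe(torch.randn(16, 64))  # -> (16, 64)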
@@ -186,7 +189,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         if self.checkpoint and self.training:
             return activation_checkpoint(
                 self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
-            )  ##TODO: check whether this will be affected by moe
+            )  # TODO: check whether this will be affected by moe
         else:
             return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
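On the TODO above: activation_checkpoint is InternLM's own wrapper, so how it interacts with MoE depends on that helper, but the general pattern it implements can be sketched with torch.utils.checkpoint. This is a simplified, hypothetical stand-in, not the project's code:

import torch.nn as nn
import torch.utils.checkpoint as cp

class CheckpointedBlock(nn.Module):
    # hypothetical stand-in for a block that recomputes activations in backward
    def __init__(self, hidden_size, checkpoint=True):
        super().__init__()
        self.checkpoint = checkpoint
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU())

    def _forward(self, hidden_states):
        return self.mlp(hidden_states)

    def forward(self, hidden_states):
        if self.checkpoint and self.training:
            # Trade compute for memory: _forward is re-run during backward.
            # use_reentrant=False also tolerates keyword arguments and outputs
            # that do not require grad, which matters once _forward starts
            # returning extra values such as an MoE auxiliary loss.
            return cp.checkpoint(self._forward, hidden_states, use_reentrant=False)
        return self._forward(hidden_states)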
@@ -550,6 +553,7 @@ def build_model_with_cfg(
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        sequence_parallel=sequence_parallel,
         num_experts=num_experts,
     )