mirror of https://github.com/InternLM/InternLM
modified: internlm/model/modeling_internlm.py
parent 8b198b2665
commit 2a52452ed2
@@ -32,6 +32,7 @@ MODEL_TYPE = "INTERNLM"
 logger = get_logger(__file__)
 RMSNorm = try_import_RMSNorm()
 
 
 class PackedFlashBaseLayer1D(nn.Module):
     """
     1D Packed Flash Base Layer.
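Note: `try_import_RMSNorm` implies the fused RMSNorm kernel is optional at import time. For reference, a pure-PyTorch RMSNorm fallback would compute the following (an illustrative sketch only; the class name and eps default are assumptions, not taken from this repo):

import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Hypothetical fallback: scale by the inverse RMS of the features."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the feature dim; no mean subtraction, unlike LayerNorm.
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)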
@@ -104,10 +105,10 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
-        ## TODO: replace num_experts and epsize with function parameter
+        # TODO: replace num_experts and epsize with function parameter
         self.num_experts = num_experts
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
         if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
                 self.mlp = FeedForward(
                     hidden_size,
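The `num_experts <= 1` check keeps the dense path untouched; `ep_size` is the expert-parallel world size, so each rank owns `num_experts // ep_size` local experts in the hunk below. A toy check of that partitioning arithmetic (the concrete numbers are made up for illustration):

# Illustrative numbers only: 8 experts sharded over an expert-parallel
# group of 4 ranks leaves 2 locally constructed experts per rank.
num_experts, ep_size = 8, 4
assert num_experts % ep_size == 0, "experts must divide evenly across the EP group"
local_experts = num_experts // ep_size  # -> 2, the length of the ModuleList below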
@@ -134,27 +135,29 @@ class PackedFlashBaseLayer1D(nn.Module):
                     dtype=dtype,
                 )
         else:
-            expert = torch.nn.ModuleList([FeedForward(
-                hidden_size,
-                int(hidden_size * gpc.config.model.mlp_ratio),
-                out_features=hidden_size,
-                process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias=False,
-                device=torch.device("cuda"),
-                dtype=torch.float,
-            ) for i in range(num_experts // ep_size)])
-            # TODO: test moe for now, need more parameter such as: capacity_factor, eval_capacity_factor, min_capacity, drop_tokens
-            self.mlp = MoE(hidden_size=hidden_size,
-                           expert=expert,
-                           ep_size=ep_size,
-                           num_experts=num_experts,
-                           k=1)
+            expert = torch.nn.ModuleList(
+                [
+                    FeedForward(
+                        hidden_size,
+                        int(hidden_size * gpc.config.model.mlp_ratio),
+                        out_features=hidden_size,
+                        process_group=gpc.get_group(ParallelMode.TENSOR),
+                        bias=False,
+                        device=torch.device("cuda"),
+                        dtype=torch.float,
+                    )
+                    for i in range(num_experts // ep_size)
+                ]
+            )
+            # TODO: test moe for now, need more parameter such as: capacity_factor,
+            # eval_capacity_factor, min_capacity, drop_tokens
+            self.mlp = MoE(hidden_size=hidden_size, expert=expert, ep_size=ep_size, num_experts=num_experts, k=1)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
         self.residual_in_fp32 = residual_in_fp32  # only make sense when using prenorm
         self.return_residual = False
-        self.reset_parameters()  ## TODO: check this should be changed when moe is added
+        self.reset_parameters()  # TODO: check this should be changed when moe is added
 
     def reset_parameters(self):
         with torch.no_grad():
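`k=1` makes this a top-1 (Switch-style) router: each token is dispatched to a single expert and the expert output is scaled by its gate probability. A self-contained sketch of that routing logic, assuming a plain softmax gate; this is not the repo's `MoE` class, whose capacity and token-dropping parameters are still marked TODO above:

import torch
from torch import nn


class Top1Router(nn.Module):
    """Toy top-1 gate: route each token to one expert, weight by the gate prob."""

    def __init__(self, hidden_size: int, experts: nn.ModuleList):
        super().__init__()
        self.gate = nn.Linear(hidden_size, len(experts), bias=False)
        self.experts = experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (tokens, hidden)
        probs = self.gate(x).softmax(dim=-1)
        top_prob, top_idx = probs.max(dim=-1)  # the k=1 routing decision
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top_idx == i
            if mask.any():
                # Each expert maps hidden -> hidden, matching out_features=hidden_size above.
                out[mask] = top_prob[mask, None] * expert(x[mask])
        return out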
@@ -186,7 +189,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         if self.checkpoint and self.training:
             return activation_checkpoint(
                 self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
-            )  ##TODO: check whether this will be affected by moe
+            )  # TODO: check whether this will be affected by moe
         else:
             return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
 
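`activation_checkpoint` is applied only when `self.checkpoint and self.training`, trading recomputation in the backward pass for lower activation memory. The same pattern with stock `torch.utils.checkpoint` looks like this (a sketch; the repo's own wrapper may handle its extra arguments differently):

import torch
from torch.utils.checkpoint import checkpoint


def forward_block(block: torch.nn.Module, hidden_states: torch.Tensor, training: bool) -> torch.Tensor:
    if training:
        # Discard intermediates now; recompute block's activations during backward.
        return checkpoint(block, hidden_states, use_reentrant=False)
    return block(hidden_states)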
@@ -550,6 +553,7 @@ def build_model_with_cfg(
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        sequence_parallel=sequence_parallel,
         num_experts=num_experts,
     )
 
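The only functional change in this hunk is threading the `sequence_parallel` flag from the builder into each layer. In InternLM-style training configs such flags arrive through a Python dict; a hypothetical fragment follows (only `mlp_ratio`, `num_experts`, and `sequence_parallel` are visible in this diff, every other field is assumed):

# Hypothetical model config fragment; field values are illustrative only.
model = dict(
    hidden_size=4096,
    mlp_ratio=8 / 3,
    num_experts=8,  # > 1 switches PackedFlashBaseLayer1D to the MoE branch
    sequence_parallel=True,
    use_swiglu=True,
    use_flash_attn=True,
)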