mirror of https://github.com/InternLM/InternLM

get moe setting from gpc

parent f5226b5152, commit fe0c342f9d
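In short: the per-layer MoE hyperparameters (moe_gate_k, moe_capacity_factor, moe_eval_capacity_factor, moe_min_capacity, moe_noisy_gate_policy, moe_drop_tokens, moe_use_rts, moe_use_residual) are no longer threaded through the PackedFlashBaseLayer1D and PackedFlashInternLm1D constructors. Instead, the MoE and GShardMOELayer modules read them from the global parallel context (gpc.config.model) via getattr with defaults. build_model_with_moe_cfg keeps the arguments so existing configs still parse, marking them unused with # pylint: disable=W0613 (pylint's unused-argument check), and TopKGate's k parameter is renamed to topk. The hunks below span two files: first the model-building module, then the MoE implementation module.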
@@ -86,14 +86,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: int = 1,
-        moe_gate_k: int = 1,
-        moe_capacity_factor: float = 1.0,
-        moe_eval_capacity_factor: float = 1.0,
-        moe_min_capacity: int = 4,
-        moe_noisy_gate_policy: str = None,
-        moe_drop_tokens: bool = True,
-        moe_use_rts: bool = True,
-        moe_use_residual: bool = False,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -131,14 +123,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         set_fp32_attr_to_module(self.norm2)

         self.num_experts = num_experts
-        self.moe_gate_k = moe_gate_k
-        self.moe_capacity_factor = moe_capacity_factor
-        self.moe_eval_capacity_factor = moe_eval_capacity_factor
-        self.moe_min_capacity = moe_min_capacity
-        self.moe_noisy_gate_policy = moe_noisy_gate_policy
-        self.moe_drop_tokens = moe_drop_tokens
-        self.moe_use_rts = moe_use_rts
-        self.moe_use_residual = moe_use_residual
         ep_size = gpc.get_world_size(ParallelMode.EXPERT)
         if num_experts <= 1:  # dense, not MoE
             if use_swiglu:
@@ -175,14 +159,6 @@ class PackedFlashBaseLayer1D(nn.Module):
                 hidden_size=hidden_size,
                 num_experts=num_experts,
                 ep_size=ep_size,
-                topk=moe_gate_k,
-                capacity_factor=moe_capacity_factor,
-                eval_capacity_factor=moe_eval_capacity_factor,
-                min_capacity=moe_min_capacity,
-                noisy_gate_policy=moe_noisy_gate_policy,
-                drop_tokens=moe_drop_tokens,
-                use_rts=moe_use_rts,
-                use_residual=moe_use_residual,
                 device=device,
                 dtype=dtype,
             )
@@ -357,14 +333,6 @@ class PackedFlashInternLm1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         num_experts: bool = 1,
-        moe_gate_k: int = 1,
-        moe_capacity_factor: float = 1.0,
-        moe_eval_capacity_factor: float = 1.0,
-        moe_min_capacity: int = 4,
-        moe_noisy_gate_policy: str = None,
-        moe_drop_tokens: bool = True,
-        moe_use_rts: bool = True,
-        moe_use_residual: bool = False,
     ):
         super().__init__()

@@ -415,14 +383,6 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     num_experts=num_experts,
-                    moe_gate_k=moe_gate_k,
-                    moe_capacity_factor=moe_capacity_factor,
-                    moe_eval_capacity_factor=moe_eval_capacity_factor,
-                    moe_min_capacity=moe_min_capacity,
-                    moe_noisy_gate_policy=moe_noisy_gate_policy,
-                    moe_drop_tokens=moe_drop_tokens,
-                    moe_use_rts=moe_use_rts,
-                    moe_use_residual=moe_use_residual,
                 )
                 for lid in range(num_layers)
             ]
@@ -559,14 +519,14 @@ def build_model_with_moe_cfg(
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
     num_experts: int = 1,
-    moe_gate_k: int = 1,
-    moe_capacity_factor: float = 1.0,
-    moe_eval_capacity_factor: float = 1.0,
-    moe_min_capacity: int = 4,
-    moe_noisy_gate_policy: str = None,
-    moe_drop_tokens: bool = True,
-    moe_use_rts: bool = True,
-    moe_use_residual: bool = False,
+    moe_gate_k: int = 1,  # pylint: disable=W0613
+    moe_capacity_factor: float = 1.0,  # pylint: disable=W0613
+    moe_eval_capacity_factor: float = 1.0,  # pylint: disable=W0613
+    moe_min_capacity: int = 4,  # pylint: disable=W0613
+    moe_noisy_gate_policy: str = None,  # pylint: disable=W0613
+    moe_drop_tokens: bool = True,  # pylint: disable=W0613
+    moe_use_rts: bool = True,  # pylint: disable=W0613
+    moe_use_residual: bool = False,  # pylint: disable=W0613
 ):
     """
     Build model with config.
@@ -633,14 +593,6 @@ def build_model_with_moe_cfg(
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
         num_experts=num_experts,
-        moe_gate_k=moe_gate_k,
-        moe_capacity_factor=moe_capacity_factor,
-        moe_eval_capacity_factor=moe_eval_capacity_factor,
-        moe_min_capacity=moe_min_capacity,
-        moe_noisy_gate_policy=moe_noisy_gate_policy,
-        moe_drop_tokens=moe_drop_tokens,
-        moe_use_rts=moe_use_rts,
-        moe_use_residual=moe_use_residual,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
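For orientation, here is a minimal sketch (not taken from the repository) of an InternLM-style Python config whose model dict is what gpc.config.model exposes at runtime. The MoE key names and defaults mirror the getattr reads introduced in this commit; the concrete values are illustrative assumptions:

# a minimal sketch, assuming an InternLM-style Python config file whose
# "model" dict is surfaced as gpc.config.model at runtime; key names and
# defaults mirror the getattr() reads in this commit, values are illustrative
model = dict(
    num_experts=8,                 # > 1 takes the MoE path ("dense, not MoE" otherwise)
    moe_gate_k=2,                  # gate top-k, default 1
    moe_capacity_factor=1.0,       # train-time capacity factor, default 1.0
    moe_eval_capacity_factor=1.0,  # eval-time capacity factor, default 1.0
    moe_min_capacity=4,            # per-expert minimum capacity, default 4
    moe_noisy_gate_policy=None,    # None, "None", "Jitter", or "RSample"
    moe_drop_tokens=True,          # drop tokens over capacity, default True
    moe_use_rts=True,              # random token selection, default True
    moe_use_residual=False,        # residual MoE branch, default False
)

The remaining hunks below are from the MoE implementation module (MoE, TopKGate, GShardMOELayer).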
@@ -1,5 +1,3 @@
-import typing
-
 import torch

 from internlm.core.context import ParallelMode
@@ -39,17 +37,9 @@ class MoE(torch.nn.Module):
         hidden_size,
         num_experts=1,
         ep_size=1,
-        topk=1,
-        capacity_factor=1.0,
-        eval_capacity_factor=1.0,
-        min_capacity=4,
-        noisy_gate_policy: typing.Optional[str] = None,
-        drop_tokens: bool = True,
-        use_rts: bool = True,
-        moe_type: str = None,
-        use_residual=False,
         device=None,
         dtype=None,
+        moe_type: str = None,
     ):

         super().__init__()
@@ -61,30 +51,19 @@ class MoE(torch.nn.Module):
         self.num_experts = num_experts
         self.num_local_experts = num_experts // self.ep_size

-        assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
-            "Unsupported noisy_gate_policy: " + noisy_gate_policy
-        )
-
         if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
                 gpc.get_group(ParallelMode.EXPERT),
                 ep_size,
                 num_experts,
-                topk,
-                capacity_factor,
-                eval_capacity_factor,
-                min_capacity,
-                noisy_gate_policy,
-                drop_tokens,
-                use_rts,
                 device,
                 dtype,
             )

         # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
-        self.use_residual = use_residual
-        if use_residual:
+        self.use_residual = getattr(gpc.config.model, "moe_use_residual", False)
+        if self.use_residual:
             self.residual_mlp = FeedForward(
                 hidden_size,
                 int(hidden_size * gpc.config.model.mlp_ratio),
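The residual branch toggled by moe_use_residual follows the PR-MoE idea in the linked paper, where a dense MLP runs alongside the experts and the two outputs are blended by a learned coefficient. A hedged sketch of that combine step, in the style of DeepSpeed's implementation; the ResidualCombine name and exact wiring are illustrative assumptions, not this repository's code:

import torch

# Illustrative sketch of a residual-MoE combine step (DeepSpeed-style);
# class name and wiring are assumptions, not code from this repository.
class ResidualCombine(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # learned 2-way mixing weights between MoE output and dense MLP output
        self.coefficient = torch.nn.Linear(hidden_size, 2)

    def forward(self, moe_out, mlp_out, hidden_states):
        coef = torch.nn.functional.softmax(self.coefficient(hidden_states), dim=-1)
        # blend expert output and residual MLP output token by token
        return moe_out * coef[..., 0:1] + mlp_out * coef[..., 1:]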
@@ -304,7 +304,7 @@ class TopKGate(Module):
         self,
         model_dim: int,
         num_experts: int,
-        k: int = 1,
+        topk: int = 1,
         capacity_factor: float = 1.0,
         eval_capacity_factor: float = 1.0,
         min_capacity: int = 8,
@@ -315,11 +315,11 @@ class TopKGate(Module):
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
-        if k not in (1, 2):
+        if topk not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
         # Deepspeed's mechisms, alway use fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
-        self.k = k
+        self.k = topk
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
         self.min_capacity = min_capacity
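One consequence of the k to topk rename: any code that constructed the gate with the keyword k must switch to topk (positional calls keep working, since the parameter keeps its slot, and the value is still stored internally as self.k). A hypothetical direct construction, with illustrative dimensions:

# hypothetical usage; model_dim and num_experts values are illustrative
gate = TopKGate(model_dim=4096, num_experts=8, topk=2)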
@@ -388,27 +388,24 @@ class GShardMOELayer(BaseMoELayer):
         ep_group,
         ep_size: int,
         num_experts: int,
-        topk,
-        capacity_factor,
-        eval_capacity_factor,
-        min_capacity,
-        noisy_gate_policy,
-        drop_tokens,
-        use_rts,
         device=None,
         dtype=None,
     ) -> None:
+        noisy_gate_policy = getattr(gpc.config.model, "noisy_gate_policy", None)
+        assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
+            "Unsupported noisy_gate_policy: " + noisy_gate_policy
+        )
         super().__init__(
             TopKGate(
                 hidden_size,
                 num_experts,
-                topk,
-                capacity_factor,
-                eval_capacity_factor,
-                min_capacity,
-                noisy_gate_policy,
-                drop_tokens,
-                use_rts,
+                topk=getattr(gpc.config.model, "moe_gate_k", 1),
+                capacity_factor=getattr(gpc.config.model, "moe_capacity_factor", 1.0),
+                eval_capacity_factor=getattr(gpc.config.model, "moe_eval_capacity_factor", 1.0),
+                min_capacity=getattr(gpc.config.model, "moe_min_capacity", 4),
+                noisy_gate_policy=getattr(gpc.config.model, "moe_noisy_gate_policy", None),
+                drop_tokens=getattr(gpc.config.model, "moe_drop_tokens", True),
+                use_rts=getattr(gpc.config.model, "moe_use_rts", True),
             ),
             torch.nn.ModuleList(
                 [