mirror of https://github.com/InternLM/InternLM
reformat code
parent 629e6a5ad1
commit f3da80a7ca
internlm/model/modeling_internlm.py
@@ -160,37 +160,9 @@ class PackedFlashBaseLayer1D(nn.Module):
                 dtype=dtype,
             )
         else:
-            experts = torch.nn.ModuleList(
-                [
-                    FeedForward(
-                        hidden_size,
-                        int(hidden_size * gpc.config.model.mlp_ratio),
-                        out_features=hidden_size,
-                        process_group=gpc.get_group(ParallelMode.TENSOR),
-                        bias=False,
-                        device=torch.device("cuda"),
-                        dtype=torch.float,
-                    )
-                    for i in range(num_experts // ep_size)
-                ]
-            )
-
-            # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
-            if moe_use_residual:
-                residual_mlp = FeedForward(
-                    hidden_size,
-                    int(hidden_size * gpc.config.model.mlp_ratio),
-                    out_features=hidden_size,
-                    process_group=gpc.get_group(ParallelMode.TENSOR),
-                    bias=False,
-                    device=torch.device("cuda"),
-                    dtype=torch.float,
-                )
-
             # replace mlp by MoE module. The expert in MoE is a FeedForward module.
             self.mlp = MoE(
                 hidden_size=hidden_size,
-                experts=experts,
                 num_experts=num_experts,
                 ep_size=ep_size,
                 k=moe_gate_k,
@@ -201,7 +173,6 @@ class PackedFlashBaseLayer1D(nn.Module):
                 drop_tokens=moe_drop_tokens,
                 use_rts=moe_use_rts,
                 use_residual=moe_use_residual,
-                residual_mlp=residual_mlp if moe_use_residual else None,
             )

         self.dropout2 = nn.Dropout(drop_rate)
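Net effect of the two hunks above: PackedFlashBaseLayer1D no longer assembles the expert ModuleList or the optional residual MLP itself; it passes only sizes and flags, and MoE constructs its own experts (see the moe.py hunks below). A minimal sketch of this ownership change, using toy stand-ins rather than the real FeedForward/MoE signatures:

# Sketch (not InternLM code): the ownership change in miniature, with a toy
# expert standing in for internlm.model.linear.FeedForward.
import torch

class ToyExpert(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)

class ToyMoE(torch.nn.Module):
    # After this commit: the MoE module builds its own local experts.
    def __init__(self, hidden_size, num_experts=4, ep_size=1):
        super().__init__()
        num_local_experts = num_experts // ep_size
        # Before, the caller constructed this list and passed it in as `experts=`.
        self.experts = torch.nn.ModuleList(
            ToyExpert(hidden_size) for _ in range(num_local_experts)
        )

mlp = ToyMoE(hidden_size=8, num_experts=4)  # the layer no longer builds experts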
internlm/model/moe.py
@@ -5,6 +5,7 @@ import torch

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
 from internlm.moe.sharded_moe import MOELayer, TopKGate
 from internlm.utils.logger import get_logger
@@ -63,7 +64,6 @@ class MoE(torch.nn.Module):
     def __init__(
         self,
         hidden_size,
-        experts,
         num_experts=1,
         ep_size=1,
         k=1,
@@ -75,7 +75,6 @@ class MoE(torch.nn.Module):
         use_rts: bool = True,
         using_default_moe: bool = True,
         use_residual=False,
-        residual_mlp=None,
     ):

         super().__init__()
@@ -91,12 +90,26 @@ class MoE(torch.nn.Module):
             f"Creating MoE layer with num_experts: {num_experts} | num_local_experts:"
             f"{self.num_local_experts} | expert_parallel_size: {self.ep_size}"
         )

         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
         )

+        # for elastic expert parallelism, experts may have multiple groups
         expert_group_name = f"ep_size_{self.ep_size}"
+        experts = torch.nn.ModuleList(
+            [
+                FeedForward(
+                    hidden_size,
+                    int(hidden_size * gpc.config.model.mlp_ratio),
+                    out_features=hidden_size,
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias=False,
+                    device=torch.device("cuda"),
+                    dtype=torch.float,
+                )
+                for _ in range(self.num_local_experts)
+            ]
+        )
         experts = Experts(experts, self.num_local_experts, expert_group_name)

         if using_default_moe:
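For orientation: with expert parallelism each rank instantiates only its share of the experts, so the comprehension added above runs over self.num_local_experts (num_experts // ep_size, matching the loop bound deleted from the layer). A quick check of the arithmetic with hypothetical sizes:

# Hypothetical sizes, for illustration only.
num_experts, ep_size = 32, 4
num_local_experts = num_experts // ep_size
assert num_local_experts == 8  # each expert-parallel rank builds 8 experts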
@@ -118,10 +131,19 @@ class MoE(torch.nn.Module):
                 self.num_local_experts,
             )

+        # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
         self.use_residual = use_residual
         if use_residual:
-            self.residual_mlp = residual_mlp
-            # coefficient is used for weighted sum of the output of expert and mlp
+            self.residual_mlp = FeedForward(
+                hidden_size,
+                int(hidden_size * gpc.config.model.mlp_ratio),
+                out_features=hidden_size,
+                process_group=gpc.get_group(ParallelMode.TENSOR),
+                bias=False,
+                device=torch.device("cuda"),
+                dtype=torch.float,
+            )
+            # coefficient is used for weighted sum of the output of expert and residual mlp
             self.coefficient = torch.nn.Linear(hidden_size, 2)

     def forward(self, hidden_states, used_token=None):
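The forward pass is not part of this diff, but the Linear(hidden_size, 2) coefficient constructed above is the Residual-MoE mixer from the cited paper (arXiv:2201.05596): a two-way gate produces per-token weights that blend the sparse MoE output with the dense residual MLP output. A hedged sketch of that combination, in the DeepSpeed-MoE style rather than InternLM's exact forward:

import torch
import torch.nn.functional as F

def residual_moe_combine(moe_out, residual_out, hidden_states, coefficient):
    # coefficient: torch.nn.Linear(hidden_size, 2), as constructed above.
    coef = F.softmax(coefficient(hidden_states), dim=-1)   # (..., 2), sums to 1
    return moe_out * coef[..., 0:1] + residual_out * coef[..., 1:2]

h = torch.randn(2, 4, 16)
coefficient = torch.nn.Linear(16, 2)
out = residual_moe_combine(torch.randn_like(h), torch.randn_like(h), h, coefficient)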
internlm/moe/sharded_moe.py
@@ -356,7 +356,6 @@ class TopKGate(Module):
         # Only top-1 and top-2 are supported at the moment.
         if k not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
-        # TODO: can we use tensor parallel here?
         # DeepSpeed's mechanism: always use fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
         self.k = k
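The surviving comment is worth unpacking: following DeepSpeed's convention, the gate's routing weights are kept in fp32 even when the rest of the model runs in fp16/bf16, since the softmax over routing logits is sensitive to low-precision rounding. A small illustration of the pattern:

import torch

model_dim, num_experts = 16, 8
wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()  # gate stays fp32

tokens = torch.randn(4, model_dim, dtype=torch.bfloat16)
logits = wg(tokens.float())             # cast activations up before routing
probs = torch.softmax(logits, dim=-1)   # fp32 softmax keeps routing stable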
@@ -437,9 +436,6 @@ class MOELayer(Base):
         self.time_moe = 0.0
         self.wall_clock_breakdown = False

-    def _set_ep_group(self, ep_group):
-        self.ep_group = ep_group
-
     def forward(self, *inputs: Tensor) -> Tensor:

         if self.wall_clock_breakdown:
train.py
@@ -262,7 +262,7 @@ def main(args):
                 start_time=start_time,
                 loss=loss,
                 moe_loss=moe_loss,
-                grad_norm=np.array(grad_norm_groups),
+                grad_norm=np.linalg.norm(grad_norm_groups),
                 metric=metric,
                 update_panel=uniscale_logger is not None,
             )
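This is the one behavioral change outside the MoE refactor: the logger now receives a single scalar rather than a vector of per-parameter-group norms. Taking the L2 norm of the per-group L2 norms gives the global gradient norm when the groups partition the parameters. With hypothetical values:

import numpy as np

grad_norm_groups = [0.3, 0.4]              # hypothetical per-group gradient norms
before = np.array(grad_norm_groups)        # vector: array([0.3, 0.4])
after = np.linalg.norm(grad_norm_groups)   # scalar: sqrt(0.3**2 + 0.4**2) == 0.5
assert np.isclose(after, 0.5)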