refactor code

pull/567/head
Wenwen Qu 2024-01-08 16:03:55 +08:00
parent c3854f924a
commit 41f8283a3e
5 changed files with 15 additions and 12 deletions

View File

@@ -5,7 +5,7 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FeedForward
-from internlm.moe.sharded_moe import GShardMOELayer
+from internlm.moe import GShardMOELayer
from internlm.utils.logger import get_logger
# global llm logger
@@ -46,7 +46,7 @@ class MoE(torch.nn.Module):
noisy_gate_policy: typing.Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
-using_default_moe: bool = True,
+moe_type: str = None,
use_residual=False,
device=None,
dtype=None,
@@ -65,7 +65,7 @@ class MoE(torch.nn.Module):
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)
-if using_default_moe:
+if moe_type is None or moe_type == "GShard":
self.moe_layer = GShardMOELayer(
hidden_size,
gpc.get_group(ParallelMode.EXPERT),

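For orientation, a minimal sketch of what the reworked selection logic amounts to: the boolean using_default_moe flag becomes a string moe_type, with None or "GShard" falling back to GShardMOELayer. The build_moe_layer helper, its **kwargs passthrough, and the error branch are illustrative assumptions, not code from this commit.

# Hypothetical helper, assuming GShardMOELayer(hidden_size, ep_group, ...) as in the hunk above.
from internlm.moe import GShardMOELayer

def build_moe_layer(hidden_size, ep_group, moe_type=None, **kwargs):
    # None keeps the previous default behaviour; "GShard" selects it explicitly.
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer(hidden_size, ep_group, **kwargs)
    # Other values are assumed to be rejected until more MoE types are registered.
    raise ValueError(f"Unsupported moe_type: {moe_type}")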
View File

@@ -0,0 +1,3 @@
+from internlm.moe.sharded_moe import GShardMOELayer
+
+__all__ = ["GShardMOELayer"]

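The new package-level re-export above is what allows the shortened import in the first file: callers can now pull GShardMOELayer from internlm.moe instead of reaching into internlm.moe.sharded_moe. A small usage sketch, assuming the package layout shown in this commit:

# Both names resolve to the same class once the re-export is in place.
from internlm.moe import GShardMOELayer
from internlm.moe.sharded_moe import GShardMOELayer as _direct

assert GShardMOELayer is _direct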
View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Union
+import torch
from torch import Tensor
from torch.nn import Module, ModuleList
@@ -30,3 +31,5 @@ class BaseMoELayer(Base):
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts
+self.l_aux = torch.zeros(1, device=torch.cuda.current_device())
+self.exp_counts = None

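The two attributes added above give every subclass of the base layer a zero-initialized auxiliary-loss tensor and an expert-count placeholder before the first forward pass. A minimal sketch of the resulting initialization, using an illustrative stand-in class and assuming a CUDA device is available (torch.cuda.current_device() requires one):

import torch
from torch.nn import Module

class _BaseMoELayerSketch(Module):  # illustrative stand-in, not the real BaseMoELayer
    def __init__(self, ep_group, ep_size, num_local_experts):
        super().__init__()
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        # Added by this commit: the load-balancing loss starts at zero on the
        # current CUDA device; expert counts are expected to be set during forward.
        self.l_aux = torch.zeros(1, device=torch.cuda.current_device())
        self.exp_counts = None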
View File

@@ -18,17 +18,14 @@ class Experts(torch.nn.Module):
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
super().__init__()
# TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
# self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
if isinstance(experts, ModuleList):
-self.experts = cast(ModuleList, experts)
+self.wrapped_experts = cast(ModuleList, experts)
else:
-self.experts = ModuleList([experts])
+self.wrapped_experts = ModuleList([experts])
self.num_local_experts = num_local_experts
# TODO: revisit allreduce for moe.gate...
-for expert in self.experts:
+for expert in self.wrapped_experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for _, param in expert.named_parameters():
param.is_expert = True
@ -37,7 +34,7 @@ class Experts(torch.nn.Module):
def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
-for chunk, expert in zip(chunks, self.experts):
+for chunk, expert in zip(chunks, self.wrapped_experts):
out = expert(chunk)
if isinstance(out, tuple):
out = out[0] # Ignore the bias term for now

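The rename from self.experts to self.wrapped_experts also renames the parameters that nn.Module.state_dict() emits, which is what forces the checkpoint prefix change in the last file. A sketch of the effect, with a plain Linear standing in for the real FeedForward expert:

from torch.nn import Linear, Module, ModuleList

class _ExpertsSketch(Module):  # illustrative stand-in for Experts
    def __init__(self):
        super().__init__()
        self.wrapped_experts = ModuleList([Linear(4, 4)])  # was: self.experts

print(list(_ExpertsSketch().state_dict().keys()))
# ['wrapped_experts.0.weight', 'wrapped_experts.0.bias']
# Before the rename these began with 'experts.0.', which is why the checkpoint
# code below now uses the '.moe_layer.experts.wrapped_experts.' prefix.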
View File

@@ -614,7 +614,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
for n, p in module.state_dict().items():
if "expert" in n and "moe_layer.gate" not in n:
moe_state_dict[n_module + "." + n] = p
moe_str_prefix = ".moe_layer.experts.experts."
moe_str_prefix = ".moe_layer.experts.wrapped_experts."
# Reorder the moe name rank, so that each checkpoint only has one expert
experts_state_dict = defaultdict(dict)
for key in list(moe_state_dict.keys()):
@@ -696,7 +696,7 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
fp = os.path.join(folder, fn)
expert_state_dict = llm_load(fp, map_location=get_current_device())
# Updating global -> local expert ids
moe_str_prefix = ".moe_layer.experts.experts."
moe_str_prefix = ".moe_layer.experts.wrapped_experts."
for key in list(expert_state_dict.keys()):
local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
expert_state_dict[local_key] = expert_state_dict.pop(key)
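
For reference, a standalone sketch of the global-to-local key rewrite performed in try_load_moe_checkpoint above; the expert ids and the 'layers.2...w1.weight' key layout are illustrative assumptions, and only the prefix and the replace logic come from the diff:

# Remap a checkpoint saved for global expert 5 onto local expert 1.
moe_str_prefix = ".moe_layer.experts.wrapped_experts."
global_expert_id, local_expert_id = 5, 1

expert_state_dict = {
    f"layers.2{moe_str_prefix}{global_expert_id}.w1.weight": "tensor placeholder",
}
for key in list(expert_state_dict.keys()):
    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
    expert_state_dict[local_key] = expert_state_dict.pop(key)

print(list(expert_state_dict.keys()))
# ['layers.2.moe_layer.experts.wrapped_experts.1.w1.weight']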