mirror of https://github.com/InternLM/InternLM
refactor code
parent c3854f924a
commit 41f8283a3e

@@ -5,7 +5,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe.sharded_moe import GShardMOELayer
+from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger

 # global llm logger

@@ -46,7 +46,7 @@ class MoE(torch.nn.Module):
         noisy_gate_policy: typing.Optional[str] = None,
         drop_tokens: bool = True,
         use_rts: bool = True,
-        using_default_moe: bool = True,
+        moe_type: str = None,
         use_residual=False,
         device=None,
         dtype=None,

@@ -65,7 +65,7 @@ class MoE(torch.nn.Module):
                 "Unsupported noisy_gate_policy: " + noisy_gate_policy
             )

-        if using_default_moe:
+        if moe_type is None or moe_type == "GShard":
             self.moe_layer = GShardMOELayer(
                 hidden_size,
                 gpc.get_group(ParallelMode.EXPERT),

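Note: together, the two hunks above replace the boolean using_default_moe flag with a string-valued moe_type, so further MoE variants can later be selected by name. A minimal sketch of the dispatch pattern this enables; build_moe_layer is a hypothetical helper, not part of this commit, and the trailing constructor arguments are elided just as they are in the hunk:

# Hypothetical helper showing the dispatch that moe_type enables; only the
# GShard branch exists in the repository at this point.
from internlm.moe import GShardMOELayer

def build_moe_layer(moe_type, hidden_size, ep_group, **kwargs):
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer(hidden_size, ep_group, **kwargs)
    raise NotImplementedError("Unsupported moe_type: " + str(moe_type))
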
@@ -0,0 +1,3 @@
+from internlm.moe.sharded_moe import GShardMOELayer
+
+__all__ = ["GShardMOELayer"]

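Judging by its contents, this new three-line file is the package initializer for internlm.moe (presumably internlm/moe/__init__.py). It re-exports GShardMOELayer so callers can import from the package instead of reaching into sharded_moe, which is exactly what the import change in the first hunk does. For illustration:

# Both forms resolve to the same class; the second is the spelling the
# first hunk of this commit switches to.
from internlm.moe.sharded_moe import GShardMOELayer as Direct
from internlm.moe import GShardMOELayer

assert GShardMOELayer is Direct
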
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Union

+import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList

@@ -30,3 +31,5 @@ class BaseMoELayer(Base):
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
+        self.l_aux = torch.zeros(1, device=torch.cuda.current_device())
+        self.exp_counts = None

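Note: BaseMoELayer now pre-allocates l_aux (the auxiliary load-balancing loss) as a zero tensor on the current CUDA device and initializes exp_counts, so both attributes exist even before the first forward pass. A sketch of how a training step might consume l_aux; the helper, coefficient, and traversal are illustrative assumptions, not code from this repository:

import torch

def add_moe_aux_loss(model: torch.nn.Module, task_loss: torch.Tensor, coeff: float = 0.01):
    # Fold every MoE layer's pre-allocated l_aux buffer into the training loss.
    aux = torch.zeros((), device=task_loss.device)
    for module in model.modules():
        if hasattr(module, "l_aux"):
            aux = aux + module.l_aux.to(task_loss.device).sum()
    return task_loss + coeff * aux
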
@@ -18,17 +18,14 @@ class Experts(torch.nn.Module):
     def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
         super().__init__()
-
-        # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
-        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])

         if isinstance(experts, ModuleList):
-            self.experts = cast(ModuleList, experts)
+            self.wrapped_experts = cast(ModuleList, experts)
         else:
-            self.experts = ModuleList([experts])
+            self.wrapped_experts = ModuleList([experts])
         self.num_local_experts = num_local_experts

         # TODO: revisit allreduce for moe.gate...
-        for expert in self.experts:
+        for expert in self.wrapped_experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
                 param.is_expert = True

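Note: this hunk renames the inner ModuleList from experts to wrapped_experts, which avoids the doubled "experts.experts." segment in state-dict keys (see the checkpoint hunks below), and keeps tagging every expert parameter with param.is_expert = True. A sketch of how such a flag can be consumed when building parameter groups; the helper and grouping policy are assumptions:

import torch

def split_expert_params(model: torch.nn.Module):
    # Partition parameters on the is_expert flag that Experts.__init__ sets,
    # e.g. to give expert weights their own optimizer / ZeRO group.
    expert_params, dense_params = [], []
    for param in model.parameters():
        (expert_params if getattr(param, "is_expert", False) else dense_params).append(param)
    return expert_params, dense_params
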
@@ -37,7 +34,7 @@ class Experts(torch.nn.Module):
     def forward(self, inputs):
         chunks = inputs.chunk(self.num_local_experts, dim=1)
         expert_outputs = []
-        for chunk, expert in zip(chunks, self.experts):
+        for chunk, expert in zip(chunks, self.wrapped_experts):
             out = expert(chunk)
             if isinstance(out, tuple):
                 out = out[0]  # Ignore the bias term for now

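For reference, forward splits its input along dim=1 into one chunk per local expert and runs each chunk through its own expert; code past the end of this hunk concatenates the results back. A toy reproduction of that chunk/zip pattern with plain Linear experts; the shapes and the final torch.cat are assumptions based on the visible lines:

import torch
from torch.nn import Linear, ModuleList

num_local_experts, d_model = 2, 8
experts = ModuleList([Linear(d_model, d_model) for _ in range(num_local_experts)])

inputs = torch.randn(4, num_local_experts, d_model)  # dim=1 indexes the local experts
chunks = inputs.chunk(num_local_experts, dim=1)      # one (4, 1, 8) slice per expert
outputs = [expert(chunk) for chunk, expert in zip(chunks, experts)]
print(torch.cat(outputs, dim=1).shape)               # torch.Size([4, 2, 8])
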
@@ -614,7 +614,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
             for n, p in module.state_dict().items():
                 if "expert" in n and "moe_layer.gate" not in n:
                     moe_state_dict[n_module + "." + n] = p
        moe_str_prefix = ".moe_layer.experts.experts."
+        moe_str_prefix = ".moe_layer.experts.wrapped_experts."
         # Reorder the moe name rank, so that each checkpoint only has one expert
         experts_state_dict = defaultdict(dict)
         for key in list(moe_state_dict.keys()):

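The reordering that begins at the end of this hunk groups expert parameters by expert id so that each saved file contains exactly one expert. A sketch of the key-grouping idea; the key format is inferred from the prefix above and from the load-side hunk below, and the loop body in the repository may differ:

from collections import defaultdict

moe_str_prefix = ".moe_layer.experts.wrapped_experts."
moe_state_dict = {
    "blocks.0.mlp.moe_layer.experts.wrapped_experts.3.w1.weight": "tensor",
    "blocks.0.mlp.moe_layer.experts.wrapped_experts.4.w1.weight": "tensor",
}

experts_state_dict = defaultdict(dict)
for key in list(moe_state_dict.keys()):
    # The integer right after the prefix is the expert index.
    expert_id = key.split(moe_str_prefix)[-1].split(".")[0]
    experts_state_dict[expert_id][key] = moe_state_dict.pop(key)

print(sorted(experts_state_dict))  # ['3', '4']
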
@@ -696,7 +696,7 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
             fp = os.path.join(folder, fn)
             expert_state_dict = llm_load(fp, map_location=get_current_device())
             # Updating global -> local expert ids
-            moe_str_prefix = ".moe_layer.experts.experts."
+            moe_str_prefix = ".moe_layer.experts.wrapped_experts."
             for key in list(expert_state_dict.keys()):
                 local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
                 expert_state_dict[local_key] = expert_state_dict.pop(key)

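Because both the save and load paths now expect the ".wrapped_experts." prefix, checkpoints written before this refactor still carry the old ".experts.experts." keys. A hedged migration sketch for such files; remap_legacy_moe_keys is not part of this commit:

def remap_legacy_moe_keys(state_dict):
    # Rename pre-refactor keys "...moe_layer.experts.experts.N..."
    # to the new "...moe_layer.experts.wrapped_experts.N..." layout.
    old, new = ".moe_layer.experts.experts.", ".moe_layer.experts.wrapped_experts."
    return {key.replace(old, new): value for key, value in state_dict.items()}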