mirror of https://github.com/InternLM/InternLM
refactor code
parent c3854f924a
commit 41f8283a3e
@@ -5,7 +5,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
-from internlm.moe.sharded_moe import GShardMOELayer
+from internlm.moe import GShardMOELayer
 from internlm.utils.logger import get_logger
 
 # global llm logger
@@ -46,7 +46,7 @@ class MoE(torch.nn.Module):
         noisy_gate_policy: typing.Optional[str] = None,
         drop_tokens: bool = True,
         use_rts: bool = True,
-        using_default_moe: bool = True,
+        moe_type: str = None,
         use_residual=False,
         device=None,
         dtype=None,
@@ -65,7 +65,7 @@ class MoE(torch.nn.Module):
            "Unsupported noisy_gate_policy: " + noisy_gate_policy
        )
 
        if moe_type is None or moe_type == "GShard":
            self.moe_layer = GShardMOELayer(
                hidden_size,
                gpc.get_group(ParallelMode.EXPERT),
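The boolean `using_default_moe` flag becomes a string-valued `moe_type`, with `None` falling back to the GShard implementation. A minimal sketch of the resulting dispatch, written as a standalone helper for illustration (the real check is inline in `MoE.__init__`, and the non-GShard branch is an assumption here):

```python
from internlm.moe import GShardMOELayer

def build_moe_layer(moe_type, *args, **kwargs):
    # Hypothetical helper mirroring the inline check in MoE.__init__:
    # an unset moe_type selects the default GShard implementation.
    if moe_type is None or moe_type == "GShard":
        return GShardMOELayer(*args, **kwargs)
    raise NotImplementedError(f"Unsupported moe_type: {moe_type}")
```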
@@ -0,0 +1,3 @@
+from internlm.moe.sharded_moe import GShardMOELayer
+
+__all__ = ["GShardMOELayer"]
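The new `internlm/moe/__init__.py` re-exports `GShardMOELayer`, which is what enables the shorter import in the first hunk; both forms resolve to the same class:

```python
from internlm.moe import GShardMOELayer              # new package-level import
from internlm.moe.sharded_moe import GShardMOELayer  # still valid, same class
```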
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Union
 
+import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
@@ -30,3 +31,5 @@ class BaseMoELayer(Base):
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
+        self.l_aux = torch.zeros(1, device=torch.cuda.current_device())
+        self.exp_counts = None
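`BaseMoELayer` now pre-initializes `l_aux` as a zero tensor on the current CUDA device and `exp_counts` as `None`. In gate-based MoE layers such as GShard, `l_aux` conventionally carries the auxiliary load-balancing loss, which the training loop folds into the task loss. A hedged sketch of that usage (the `moe_loss_coeff` scaling factor is an assumption, not part of this diff):

```python
import torch

# Sketch: sum the auxiliary balancing losses exposed by MoE submodules.
# `model` and `task_loss` come from the surrounding training step.
moe_loss = torch.zeros(1, device=torch.cuda.current_device())
for module in model.modules():
    if hasattr(module, "l_aux"):
        moe_loss = moe_loss + module.l_aux
total_loss = task_loss + moe_loss_coeff * moe_loss  # moe_loss_coeff is hypothetical
```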
@@ -18,17 +18,14 @@ class Experts(torch.nn.Module):
     def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1, expert_group_name=None):
         super().__init__()
 
-        # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
-        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
-
         if isinstance(experts, ModuleList):
-            self.experts = cast(ModuleList, experts)
+            self.wrapped_experts = cast(ModuleList, experts)
         else:
-            self.experts = ModuleList([experts])
+            self.wrapped_experts = ModuleList([experts])
         self.num_local_experts = num_local_experts
 
         # TODO: revisit allreduce for moe.gate...
-        for expert in self.experts:
+        for expert in self.wrapped_experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
             for _, param in expert.named_parameters():
                 param.is_expert = True
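Renaming the inner `ModuleList` from `self.experts` to `self.wrapped_experts` changes every expert parameter name in the model's state dict, which is why the checkpoint prefix in the later hunks moves from `.moe_layer.experts.experts.` to `.moe_layer.experts.wrapped_experts.`. A small sketch of the new key layout, using a plain `Linear` as a stand-in expert (the import path for `Experts` is assumed):

```python
import torch
from internlm.moe.experts import Experts  # module path assumed, not shown in this diff

# A plain Linear stands in for FeedForward just to show the key layout.
experts = Experts(torch.nn.Linear(8, 8), num_local_experts=1)
print(list(experts.state_dict().keys()))
# ['wrapped_experts.0.weight', 'wrapped_experts.0.bias']
```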
@@ -37,7 +34,7 @@ class Experts(torch.nn.Module):
     def forward(self, inputs):
         chunks = inputs.chunk(self.num_local_experts, dim=1)
         expert_outputs = []
-        for chunk, expert in zip(chunks, self.experts):
+        for chunk, expert in zip(chunks, self.wrapped_experts):
             out = expert(chunk)
             if isinstance(out, tuple):
                 out = out[0]  # Ignore the bias term for now
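`Experts.forward` splits the incoming tensor along dim 1 into one chunk per local expert, runs each chunk through its own expert, and concatenates the results in the unchanged tail of the method. A standalone sketch of that chunk-per-expert pattern (shapes are illustrative, not the exact tensors produced by the GShard dispatch):

```python
import torch
from torch import nn

experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
inputs = torch.randn(2, 4, 8, 16)  # dim=1 indexes the local experts in this sketch

chunks = inputs.chunk(len(experts), dim=1)             # one chunk per local expert
outputs = torch.cat([e(c) for c, e in zip(chunks, experts)], dim=1)
print(outputs.shape)  # torch.Size([2, 4, 8, 16])
```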
@@ -614,7 +614,7 @@ def try_save_moe_checkpoint(folder, model, tp_rank, pp_rank):
            for n, p in module.state_dict().items():
                if "expert" in n and "moe_layer.gate" not in n:
                    moe_state_dict[n_module + "." + n] = p
-        moe_str_prefix = ".moe_layer.experts.experts."
+        moe_str_prefix = ".moe_layer.experts.wrapped_experts."
         # Reorder the moe name rank, so that each checkpoint only has one expert
         experts_state_dict = defaultdict(dict)
         for key in list(moe_state_dict.keys()):
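After collecting the expert parameters, `try_save_moe_checkpoint` groups keys by the expert index that follows the `.moe_layer.experts.wrapped_experts.` prefix so that each saved file contains exactly one expert. A simplified sketch of that grouping step (the real function also folds in tensor- and pipeline-parallel ranks):

```python
from collections import defaultdict

def group_by_expert(moe_state_dict, moe_str_prefix=".moe_layer.experts.wrapped_experts."):
    """Simplified sketch: bucket expert parameters by the index after the prefix."""
    experts_state_dict = defaultdict(dict)
    for key in list(moe_state_dict.keys()):
        # e.g. "...wrapped_experts.3.w1.weight" -> expert id "3"
        expert_id = key.split(moe_str_prefix)[-1].split(".")[0]
        experts_state_dict[expert_id][key] = moe_state_dict.pop(key)
    return experts_state_dict
```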
@@ -696,7 +696,7 @@ def try_load_moe_checkpoint(folder, model, state_dict, tp_rank, pp_rank):
             fp = os.path.join(folder, fn)
             expert_state_dict = llm_load(fp, map_location=get_current_device())
             # Updating global -> local expert ids
-            moe_str_prefix = ".moe_layer.experts.experts."
+            moe_str_prefix = ".moe_layer.experts.wrapped_experts."
             for key in list(expert_state_dict.keys()):
                 local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
                 expert_state_dict[local_key] = expert_state_dict.pop(key)
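On load, the hunk above applies the inverse rename: keys stored under a global expert id are rewritten to the local expert id owned by the current expert-parallel rank before being merged into the model state dict. The id arithmetic itself lives outside this hunk; the usual convention (assumed here, not shown in the diff) is that each rank owns a contiguous block of experts:

```python
# Assumed convention, not shown in this diff: each expert-parallel rank
# owns num_local_experts consecutive experts.
def to_global_expert_id(ep_rank: int, num_local_experts: int, local_expert_id: int) -> int:
    return ep_rank * num_local_experts + local_expert_id

# e.g. with 4 local experts per rank, local expert 2 on rank 3 is global expert 14
assert to_global_expert_id(3, 4, 2) == 14
```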