mirror of https://github.com/InternLM/InternLM
refactor code
parent 5539f9db50
commit 196514d87f
@@ -6,7 +6,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
-from internlm.moe.sharded_moe import MOELayer, TopKGate
+from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
 from internlm.utils.logger import get_logger
 
 # global llm logger
@@ -87,7 +87,7 @@ class MoE(torch.nn.Module):
         experts = Experts(experts, self.num_local_experts, expert_group_name)
 
         if using_default_moe:
-            self.moe_layer = MOELayer(
+            self.moe_layer = GShardMOELayer(
                 TopKGate(
                     hidden_size,
                     num_experts,
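For orientation, here is a minimal sketch of how the refactored wrapper wires a gate and an expert bank into the renamed layer, following the call pattern visible in the hunk above. Everything not shown in the diff (the stand-in Linear expert, the group name, ep_group=None, and TopKGate's remaining defaults) is an illustrative assumption, not the repository's exact code:

    import torch

    from internlm.moe.experts import Experts
    from internlm.moe.sharded_moe import GShardMOELayer, TopKGate

    hidden_size, num_experts, ep_size = 1024, 8, 2        # illustrative sizes
    num_local_experts = num_experts // ep_size

    # The real MoE wrapper passes a FeedForward block and derives the group
    # name / process group from the parallel context; plain stand-ins here.
    expert = torch.nn.Linear(hidden_size, hidden_size)
    experts = Experts(expert, num_local_experts, "moe_ep_group_0")

    moe_layer = GShardMOELayer(
        TopKGate(hidden_size, num_experts),   # further gating args assumed to keep their defaults
        experts,
        ep_group=None,                        # placeholder; real code passes an expert-parallel group
        ep_size=ep_size,
        num_local_experts=num_local_experts,
    )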
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+from torch import Tensor
+from torch.nn import Module
+
+if TYPE_CHECKING:
+    Base = Module[Tensor]
+else:
+    Base = Module
+
+
+class BaseMoELayer(Base):
+    """
+    Base MoE Layer.
+    """
+
+    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
+        super().__init__()
+        self.gate = gate
+        self.experts = experts
+        self.ep_group = ep_group
+        self.ep_size = ep_size
+        self.num_local_experts = num_local_experts
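The new base class only centralizes the state every MoE layer shares; concrete layers subclass it and supply their own routing in forward. A minimal sketch of such a hypothetical variant, assuming a plain "hidden states in, mixed expert output out" forward interface (nothing below is part of this commit):

    import torch
    from torch import Tensor
    from torch.nn import Linear, ModuleList

    from internlm.moe.base_moe import BaseMoELayer   # path inferred from "from .base_moe import BaseMoELayer"


    class SoftMoELayer(BaseMoELayer):
        """Hypothetical subclass: a soft (dense) mixture over the local experts only,
        with none of GShardMOELayer's dispatch / all-to-all logic."""

        def forward(self, hidden_states: Tensor) -> Tensor:
            # The stored gate is assumed here to be a Linear giving one logit per local expert
            weights = torch.softmax(self.gate(hidden_states), dim=-1)        # [..., num_local_experts]
            expert_outs = torch.stack(
                [expert(hidden_states) for expert in self.experts], dim=-1   # experts assumed to be a ModuleList
            )                                                                # [..., hidden, num_local_experts]
            return (expert_outs * weights.unsqueeze(-2)).sum(dim=-1)


    layer = SoftMoELayer(Linear(8, 4), ModuleList(Linear(8, 8) for _ in range(4)),
                         ep_group=None, ep_size=1, num_local_experts=4)
    out = layer(torch.randn(2, 16, 8))   # -> shape (2, 16, 8)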
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -15,14 +15,11 @@ from torch.nn import Module
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
+from .base_moe import BaseMoELayer
+
 # global llm logger
 logger = get_logger(__file__)
 
-if TYPE_CHECKING:
-    Base = Module[Tensor]
-else:
-    Base = Module
-
 uniform_map: Dict[torch.device, Callable] = {}
 gumbel_map: Dict[torch.device, Callable] = {}
 exp_selection_uniform_map: Dict[torch.device, Callable] = {}
@@ -387,7 +384,7 @@ class TopKGate(Module):
         return gate_output
 
 
-class MOELayer(Base):
+class GShardMOELayer(BaseMoELayer):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::
 
@@ -406,12 +403,8 @@ class MOELayer(Base):
    """
 
    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
-        super().__init__()
-        self.gate = gate
-        self.experts = experts
-        self.ep_group = ep_group
-        self.ep_size = ep_size
-        self.num_local_experts = num_local_experts
+        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
 
         self.time_falltoall = 0.0
         self.time_salltoall = 0.0
         self.time_moe = 0.0
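Since the subclass constructor now just forwards its arguments, the shared attributes are set by BaseMoELayer.__init__ while the timing counters remain specific to GShardMOELayer. A small hedged check of that split, assuming the constructors do nothing beyond what the hunks above show (the stand-in Identity modules and ep_group=None are illustrative only):

    from torch.nn import Identity

    from internlm.moe.sharded_moe import GShardMOELayer

    layer = GShardMOELayer(Identity(), Identity(), ep_group=None, ep_size=1, num_local_experts=4)

    assert layer.num_local_experts == 4    # stored by BaseMoELayer.__init__
    assert layer.time_moe == 0.0           # still initialized in GShardMOELayer.__init__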