mirror of https://github.com/InternLM/InternLM
refactor code
parent 5539f9db50
commit 196514d87f
@@ -6,7 +6,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
-from internlm.moe.sharded_moe import MOELayer, TopKGate
+from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
 from internlm.utils.logger import get_logger

 # global llm logger
@@ -87,7 +87,7 @@ class MoE(torch.nn.Module):
         experts = Experts(experts, self.num_local_experts, expert_group_name)

         if using_default_moe:
-            self.moe_layer = MOELayer(
+            self.moe_layer = GShardMOELayer(
                 TopKGate(
                     hidden_size,
                     num_experts,
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+from torch import Tensor
+from torch.nn import Module
+
+if TYPE_CHECKING:
+    Base = Module[Tensor]
+else:
+    Base = Module
+
+
+class BaseMoELayer(Base):
+    """
+    Base MoE Layer.
+    """
+
+    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
+        super().__init__()
+        self.gate = gate
+        self.experts = experts
+        self.ep_group = ep_group
+        self.ep_size = ep_size
+        self.num_local_experts = num_local_experts
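The new base class (imported later as "from .base_moe import BaseMoELayer") centralizes the state every MoE layer variant carries: the gate, the expert bank, and the expert-parallel group/size/local-expert bookkeeping. The if TYPE_CHECKING branch is only seen by static type checkers; at runtime Base is plain torch.nn.Module. Below is a minimal, hypothetical sketch of how a concrete layer can build on it; ToyMoELayer and the nn.Linear stand-ins are illustrations written for this note, not code from this commit, and the import path is assumed from the relative import added to sharded_moe further down.

import torch
from torch import nn

from internlm.moe.base_moe import BaseMoELayer  # path assumed from this commit


class ToyMoELayer(BaseMoELayer):
    """Illustrative subclass; the real GShardMOELayer adds top-k gating and all-to-all dispatch."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The shared attributes come from BaseMoELayer.__init__: self.gate,
        # self.experts, self.ep_group, self.ep_size, self.num_local_experts.
        scores = torch.sigmoid(self.gate(hidden_states))   # toy gate score per token
        return self.experts(hidden_states) * scores        # toy: every token goes through the experts


gate = nn.Linear(8, 1)       # stand-in gate module
experts = nn.Linear(8, 8)    # stand-in expert bank
layer = ToyMoELayer(gate, experts, ep_group=None, ep_size=1, num_local_experts=1)
out = layer(torch.randn(4, 8))  # -> shape (4, 8)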
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -15,14 +15,11 @@ from torch.nn import Module
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer

+from .base_moe import BaseMoELayer
+
 # global llm logger
 logger = get_logger(__file__)

-if TYPE_CHECKING:
-    Base = Module[Tensor]
-else:
-    Base = Module
-
 uniform_map: Dict[torch.device, Callable] = {}
 gumbel_map: Dict[torch.device, Callable] = {}
 exp_selection_uniform_map: Dict[torch.device, Callable] = {}
@@ -387,7 +384,7 @@ class TopKGate(Module):
         return gate_output


-class MOELayer(Base):
+class GShardMOELayer(BaseMoELayer):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::

@@ -406,12 +403,8 @@ class MOELayer(Base):
     """

     def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
-        super().__init__()
-        self.gate = gate
-        self.experts = experts
-        self.ep_group = ep_group
-        self.ep_size = ep_size
-        self.num_local_experts = num_local_experts
+        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
         self.time_falltoall = 0.0
         self.time_salltoall = 0.0
         self.time_moe = 0.0
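Read together with the rename hunk above, the constructor of the renamed layer now delegates the shared state to the base class and keeps only its timing accumulators. A sketch of the resulting __init__, reconstructed from the added and retained lines visible in this hunk; anything past what the hunk shows, including the forward pass, is omitted and unchanged by this commit.

from torch.nn import Module

from internlm.moe.base_moe import BaseMoELayer  # added to sharded_moe as "from .base_moe import BaseMoELayer"


class GShardMOELayer(BaseMoELayer):
    """Constructor after this commit (sketch); the rest of the class body is not shown in the diff."""

    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
        # gate, experts, ep_group, ep_size and num_local_experts are now stored by BaseMoELayer
        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
        # per-forward timing accumulators stay local to this subclass
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0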