refactor code

pull/567/head
Wenwen Qu 2024-01-03 17:39:37 +08:00
parent 5539f9db50
commit 196514d87f
4 changed files with 31 additions and 15 deletions

@@ -6,7 +6,7 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.linear import FeedForward
 from internlm.moe.experts import Experts
-from internlm.moe.sharded_moe import MOELayer, TopKGate
+from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
 from internlm.utils.logger import get_logger

 # global llm logger
@@ -87,7 +87,7 @@ class MoE(torch.nn.Module):
         experts = Experts(experts, self.num_local_experts, expert_group_name)

         if using_default_moe:
-            self.moe_layer = MOELayer(
+            self.moe_layer = GShardMOELayer(
                 TopKGate(
                     hidden_size,
                     num_experts,
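For callers of this module, only the class name changes; the constructor arguments stay the same. A minimal migration sketch for downstream code (assuming no call sites beyond the one shown above):

```python
# Before this PR the layer was imported as:
#   from internlm.moe.sharded_moe import MOELayer
# After this PR the same class is exposed under its new name:
from internlm.moe.sharded_moe import GShardMOELayer, TopKGate
```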

internlm/moe/__init__.py  (new file, +0)

internlm/moe/base_moe.py  (new file, +23)

@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+from torch import Tensor
+from torch.nn import Module
+
+if TYPE_CHECKING:
+    Base = Module[Tensor]
+else:
+    Base = Module
+
+
+class BaseMoELayer(Base):
+    """
+    Base MoE Layer.
+    """
+
+    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
+        super().__init__()
+        self.gate = gate
+        self.experts = experts
+        self.ep_group = ep_group
+        self.ep_size = ep_size
+        self.num_local_experts = num_local_experts
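To illustrate how the new base class is meant to be used, here is a minimal hypothetical subclass; the pass-through forward is a placeholder for illustration only, not the GShard dispatch logic:

```python
import torch
from torch import nn

from internlm.moe.base_moe import BaseMoELayer


class PassThroughMoELayer(BaseMoELayer):
    """Hypothetical subclass: all shared state is stored by BaseMoELayer.__init__."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Real subclasses (e.g. GShardMOELayer) route tokens via self.gate and
        # exchange them across self.ep_group; this sketch only calls the experts.
        return self.experts(x)


# Stand-in modules; a real layer would receive a TopKGate and an Experts container.
layer = PassThroughMoELayer(nn.Identity(), nn.Linear(8, 8), ep_group=None, ep_size=1, num_local_experts=1)
out = layer(torch.randn(4, 8))
```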

internlm/moe/sharded_moe.py

@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -15,14 +15,11 @@ from torch.nn import Module
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer

+from .base_moe import BaseMoELayer
+
 # global llm logger
 logger = get_logger(__file__)

-if TYPE_CHECKING:
-    Base = Module[Tensor]
-else:
-    Base = Module
-
 uniform_map: Dict[torch.device, Callable] = {}
 gumbel_map: Dict[torch.device, Callable] = {}
 exp_selection_uniform_map: Dict[torch.device, Callable] = {}
@@ -387,7 +384,7 @@
         return gate_output


-class MOELayer(Base):
+class GShardMOELayer(BaseMoELayer):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::
@@ -406,12 +403,8 @@
     """

     def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
-        super().__init__()
-        self.gate = gate
-        self.experts = experts
-        self.ep_group = ep_group
-        self.ep_size = ep_size
-        self.num_local_experts = num_local_experts
+        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)
+
         self.time_falltoall = 0.0
         self.time_salltoall = 0.0
         self.time_moe = 0.0
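This last hunk is the core of the refactor: five duplicated attribute assignments collapse into one delegation to BaseMoELayer, leaving only the timing counters specific to this variant. A quick construction sketch (stand-in gate/experts modules; passing ep_group=None is an assumption that holds only because nothing is communicated at construction time, and is not something this PR tests):

```python
from torch import nn

from internlm.moe.sharded_moe import GShardMOELayer

# Stand-ins for illustration; a real setup passes a TopKGate, an Experts
# container, and a torch.distributed expert-parallel process group.
layer = GShardMOELayer(nn.Identity(), nn.Linear(8, 8), ep_group=None, ep_size=1, num_local_experts=1)

assert layer.num_local_experts == 1  # stored by BaseMoELayer.__init__
assert layer.time_moe == 0.0         # initialized in GShardMOELayer.__init__
```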