mirror of https://github.com/InternLM/InternLM
modified: .pre-commit-config.yaml
modified: internlm/model/moe.py modified: internlm/model/modeling_internlm.pypull/375/head
parent
5b6cf7cab0
commit
8b198b2665
|
@ -49,5 +49,5 @@ repos:
|
|||
args:
|
||||
[
|
||||
'--rcfile=.pylintrc',
|
||||
'--disable=C0330, C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
|
||||
'--disable=C0330, C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203,W1202'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.moe.sharded_moe import MOELayer, TopKGate
|
||||
from internlm.moe.experts import Experts
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
from internlm.moe.experts import Experts
|
||||
from internlm.moe.sharded_moe import MOELayer, TopKGate
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
|
@ -11,8 +13,6 @@ from internlm.utils.logger import get_logger
|
|||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
import typing
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
@ -35,6 +35,7 @@ def is_moe_param(param: torch.Tensor) -> bool:
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""Initialize an MoE layer.
|
||||
|
||||
|
@ -47,51 +48,70 @@ class MoE(torch.nn.Module):
|
|||
capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
|
||||
drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
|
||||
noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'
|
||||
or 'None'.
|
||||
drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
|
||||
infinite capacity).
|
||||
use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
expert,
|
||||
num_experts=1,
|
||||
ep_size=1,
|
||||
k=1,
|
||||
capacity_factor=1.,
|
||||
eval_capacity_factor=1.,
|
||||
min_capacity=4,
|
||||
noisy_gate_policy: typing.Optional[str] = None,
|
||||
drop_tokens: bool = True,
|
||||
use_rts: bool = True,
|
||||
using_default_moe: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
expert,
|
||||
num_experts=1,
|
||||
ep_size=1,
|
||||
k=1,
|
||||
capacity_factor=1.0,
|
||||
eval_capacity_factor=1.0,
|
||||
min_capacity=4,
|
||||
noisy_gate_policy: typing.Optional[str] = None,
|
||||
drop_tokens: bool = True,
|
||||
use_rts: bool = True,
|
||||
using_default_moe: bool = True,
|
||||
):
|
||||
|
||||
super(MoE, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
|
||||
assert (
|
||||
num_experts % ep_size == 0
|
||||
), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
|
||||
self.ep_size = ep_size
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = num_experts // self.ep_size
|
||||
|
||||
logger.info(
|
||||
f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}')
|
||||
f"""Creating MoE layer with num_experts: {num_experts} | num_local_experts:
|
||||
{self.num_local_experts} | expert_parallel_size: {self.ep_size}"""
|
||||
)
|
||||
|
||||
assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
|
||||
'Unsupported noisy_gate_policy: ' + noisy_gate_policy
|
||||
assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
|
||||
"Unsupported noisy_gate_policy: " + noisy_gate_policy
|
||||
)
|
||||
|
||||
experts = Experts(expert, self.num_local_experts)
|
||||
|
||||
if using_default_moe:
|
||||
self.moe_layer = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
|
||||
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
|
||||
experts,
|
||||
gpc.get_group(ParallelMode.EXPERT),
|
||||
self.ep_size,
|
||||
self.num_local_experts)
|
||||
|
||||
self.moe_layer = MOELayer(
|
||||
TopKGate(
|
||||
hidden_size,
|
||||
num_experts,
|
||||
k,
|
||||
capacity_factor,
|
||||
eval_capacity_factor,
|
||||
min_capacity,
|
||||
noisy_gate_policy,
|
||||
drop_tokens,
|
||||
use_rts,
|
||||
),
|
||||
experts,
|
||||
gpc.get_group(ParallelMode.EXPERT),
|
||||
self.ep_size,
|
||||
self.num_local_experts,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, used_token=None):
|
||||
""" MoE forward
|
||||
"""MoE forward
|
||||
|
||||
Arguments:
|
||||
hidden_states (Tensor): input to the layer
|
||||
|
|
Loading…
Reference in New Issue