mirror of https://github.com/InternLM/InternLM
reformat code
parent c357288a8b
commit 5b6cf7cab0
@@ -49,5 +49,5 @@ repos:
         args:
           [
             '--rcfile=.pylintrc',
-            '--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
+            '--disable=C0330, C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
           ]
@@ -10,21 +10,24 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py

 # DeepSpeed Team

+from typing import Union, cast

 import torch
-import copy
 from torch.nn import Module, ModuleList
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 class Experts(torch.nn.Module):
+    """
+    Local Experts.
+    """

     def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
-        super(Experts, self).__init__()
+        super().__init__()

         # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
         # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])

-        if type(experts) == ModuleList:
+        if isinstance(experts, ModuleList):
             self.experts = cast(ModuleList, experts)
         else:
             self.experts = ModuleList([experts])
@@ -33,7 +36,7 @@ class Experts(torch.nn.Module):
         # TODO: revisit allreduce for moe.gate...
         for expert in self.experts:
             # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
-            for name, param in expert.named_parameters():
+            for _, param in expert.named_parameters():
                 param.all_reduce = False

     def forward(self, inputs):
@@ -41,7 +44,7 @@ class Experts(torch.nn.Module):
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             out = expert(chunk)
-            if type(out) is tuple:
+            if isinstance(out, tuple):
                 out = out[0]  # Ignore the bias term for now
             expert_outputs += [out]

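For orientation, a minimal standalone sketch of how this Experts wrapper behaves; the Linear experts and tensor sizes are made-up stand-ins, and the chunking along the expert dimension follows the upstream DeepSpeed implementation this file adapts.

import torch
from torch import nn

# Two hypothetical local experts; in InternLM these would be FeedForward blocks.
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

# Mimic Experts.forward: one chunk per local expert along dim=1, then concatenate.
inputs = torch.randn(4, 2, 16, 8)  # assumed layout: (ep_size, num_local_experts, capacity, d_model)
chunks = inputs.chunk(len(experts), dim=1)
expert_outputs = [expert(chunk) for expert, chunk in zip(experts, chunks)]
output = torch.cat(expert_outputs, dim=1)
assert output.shape == inputs.shape
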
@@ -1,14 +1,3 @@
-import torch.distributed as dist
-
-from internlm.utils.logger import get_logger
-from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.core.context import global_context as gpc
-from internlm.core.context import ParallelMode
-
-
-# global llm logger
-logger = get_logger(__file__)
-
 """
 The file has been adapted from the following files:
 https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
@@ -22,13 +11,19 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 # DeepSpeed Team


-from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
-import torch.nn.functional as F
+
+from internlm.utils.logger import get_logger
+from internlm.utils.megatron_timers import megatron_timer as timer

+# global llm logger
+logger = get_logger(__file__)

 if TYPE_CHECKING:
     Base = Module[Tensor]
@@ -57,9 +52,9 @@ def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
         return x
     uniform = uniform_map.get(device)
     if uniform is None:
-        uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
-                                                      high=torch.tensor(1.0 + epsilon,
-                                                                        device=device)).rsample  # type: ignore
+        uniform = torch.distributions.uniform.Uniform(
+            low=torch.tensor(1.0 - epsilon, device=device), high=torch.tensor(1.0 + epsilon, device=device)
+        ).rsample  # type: ignore
         uniform_map[device] = uniform
     return x * uniform(x.shape)

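For reference, a self-contained sketch of the scaling that multiplicative_jitter applies (shapes are arbitrary): every element is multiplied by a uniform sample drawn from [1 - epsilon, 1 + epsilon].

import torch

epsilon = 1e-2
x = torch.randn(4, 3)
uniform = torch.distributions.uniform.Uniform(
    low=torch.tensor(1.0 - epsilon), high=torch.tensor(1.0 + epsilon)
).rsample
jittered = x * uniform(x.shape)  # each entry perturbed by at most ~1%
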
@@ -73,23 +68,28 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
         gumbel_map[device] = gumbel
     return gumbel(shape)


 # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
 # See https://arxiv.org/pdf/2006.16668.pdf for details.


 # Based on https://github.com/pytorch/pytorch/pull/40762
 class _AllToAll(torch.autograd.Function):
+    """
+    All to all communication
+    """
+
     @staticmethod
     def forward(
         ctx: Any,
         # TODO: replace with DS process group
         group: torch.distributed.ProcessGroup,
-        input: Tensor) -> Tensor:  # type: ignore
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
         ctx.group = group
-        input = input.contiguous()
-        output = torch.empty_like(input)
-        dist.all_to_all_single(output, input, group=group)
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
         return output

     @staticmethod
@@ -107,26 +107,26 @@ USE_EINSUM = True
 def einsum(rule, a, b):
     if USE_EINSUM:
         return torch.einsum(rule, a, b)
-    elif rule == 's,se->se':
-        ## [1, s] * [s, e]
+    elif rule == "s,se->se":
+        # [1, s] * [s, e]
         return a.reshape(a.shape[0], -1) * b
-    elif rule == 'se,sc->sec':
-        ## [s,e,1] * [s,1,c]
+    elif rule == "se,sc->sec":
+        # [s,e,1] * [s,1,c]
         return a.unsqueeze(2) * b.unsqueeze(1)
-    elif rule == 'se,se->s':
-        ## [s,1,e] * [s,e,1]
+    elif rule == "se,se->s":
+        # [s,1,e] * [s,e,1]
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
-    elif rule == 'sec,sm->ecm':
-        ## [e*c, s] * [s, m]
+    elif rule == "sec,sm->ecm":
+        # [e*c, s] * [s, m]
         s = a.shape[0]
         e = a.shape[1]
         c = a.shape[2]
         m = b.shape[1]
         return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
-    elif rule == 'sec,ecm->sm':
-        ## [s, e*c] * [e*c, m]
+    elif rule == "sec,ecm->sm":
+        # [s, e*c] * [e*c, m]
         return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
-    elif rule == 'ks,ksm->sm':
+    elif rule == "ks,ksm->sm":
         k = b.shape[0]
         s = b.shape[1]
         m = b.shape[2]
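As a quick sanity check (illustrative only, not part of the commit), the hand-written branch for the "sec,sm->ecm" rule can be compared against torch.einsum on small dummy shapes.

import torch

s, e, c, m = 6, 4, 3, 5
a = torch.randn(s, e, c)  # dispatch mask: (s)equence x (e)xpert x (c)apacity
b = torch.randn(s, m)     # token embeddings: (s)equence x (m)odel
manual = torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)  # the non-einsum branch
reference = torch.einsum("sec,sm->ecm", a, b)
assert torch.allclose(manual, reference, atol=1e-6)
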
@@ -172,16 +172,17 @@ def _one_hot_to_float(x, num_classes):
     return F.one_hot(x, num_classes=num_classes).float()


-def top1gating(logits: Tensor,
+def top1gating(
+    logits: Tensor,
     capacity_factor: float,
     min_capacity: int,
     used_token: Tensor = None,
     noisy_gate_policy: Optional[str] = None,
     drop_tokens: bool = True,
     use_rts: bool = True,
-    use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top1Gating on logits."""
-    if noisy_gate_policy == 'RSample':
+    if noisy_gate_policy == "RSample":
         logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
@@ -190,7 +191,7 @@ def top1gating(logits: Tensor,

     # Create a mask for 1st's expert per token
     # noisy gating
-    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
+    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == "RSample" else gates, dim=1)
     num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)

@@ -199,7 +200,7 @@ def top1gating(logits: Tensor,
         mask1 = einsum("s,se->se", used_token, mask1)

     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

     # if we don't want to drop any tokens
     if not drop_tokens:
@@ -216,43 +217,29 @@ def top1gating(logits: Tensor,
     if use_rts:
         uniform = exp_selection_uniform_map.get(logits.device)
         if uniform is None:
-            uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
-                                                          high=torch.tensor(1.0, device=logits.device)).rsample
+            uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device)
+            ).rsample
             exp_selection_uniform_map[logits.device] = uniform

         mask1_rand = mask1 * uniform(mask1.shape)
     else:
         mask1_rand = mask1

-    assert logits.shape[
-        0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
+    assert (
+        logits.shape[0] >= min_capacity
+    ), """No. of tokens (batch-size) should be greater than min_capacity.
+    Either set min_capacity to 0 or increase your batch size."""

     top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index

     new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
     mask1 = new_mask1

-    if use_tutel:
-        # Tutel doesn't support index values masked with zero
-        # so we need to replace masked indices with -1
-        indices_mask = mask1.sum(dim=1) * num_experts - 1
-        indices1_s = torch.min(indices1_s, indices_mask)
-
     # Compute locations in capacity buffer

     locations1 = torch.cumsum(mask1, dim=0) - 1

-    if use_tutel:
-        gates1_s = (gates * mask1).sum(dim=1)
-        locations1_s = torch.sum(locations1 * mask1, dim=1)
-        return l_aux, capacity, num_experts, [
-            indices1_s,
-        ], [
-            locations1_s,
-        ], [
-            gates1_s,
-        ], exp_counts
-
     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)

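To make the capacity and Random-Token-Selection step concrete, here is a standalone sketch; the capacity formula and the use of torch.topk in place of the _top_idx helper are assumptions based on the upstream DeepSpeed code that this file adapts.

import math

import torch
import torch.nn.functional as F

num_tokens, num_experts, capacity_factor, min_capacity = 16, 4, 1.0, 2
capacity = max(math.ceil(num_tokens / num_experts * capacity_factor), min_capacity)

logits = torch.randn(num_tokens, num_experts)
gates = F.softmax(logits, dim=1)
indices1_s = torch.argmax(gates, dim=1)
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# Perturb the one-hot mask with random priorities, keep the `capacity`
# highest-priority tokens per expert column, and drop the rest.
mask1_rand = mask1 * torch.rand_like(mask1, dtype=torch.float)
top_idx = torch.topk(mask1_rand, k=capacity, dim=0)[1]
mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)

assert int(torch.sum(mask1, dim=0).max()) <= capacity  # no expert exceeds its capacity
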
@@ -295,7 +282,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     locations2 += torch.sum(mask1, dim=0, keepdim=True)

     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

     # Compute l_aux
     me = torch.mean(gates, dim=0)
@@ -352,7 +339,8 @@ class TopKGate(Module):

     wg: torch.nn.Linear

-    def __init__(self,
+    def __init__(
+        self,
         model_dim: int,
         num_experts: int,
         k: int = 1,
@@ -361,12 +349,13 @@ class TopKGate(Module):
         min_capacity: int = 8,
         noisy_gate_policy: Optional[str] = None,
         drop_tokens: bool = True,
-        use_rts: bool = True) -> None:
+        use_rts: bool = True,
+    ) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
-        if k != 1 and k != 2:
-            raise ValueError('Only top-1 and top-2 gatings are supported.')
+        if k not in (1, 2):
+            raise ValueError("Only top-1 and top-2 gatings are supported.")
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
         self.k = k
         self.capacity_factor = capacity_factor
@@ -378,34 +367,40 @@ class TopKGate(Module):
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts

-    def forward(self,
-                input: torch.Tensor,
-                used_token: torch.Tensor = None,
-                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
+    def forward(
+        self, inputs: torch.Tensor, used_token: torch.Tensor = None
+    ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore

         if self.wall_clock_breakdown:
-            timer('TopKGate').start()
+            timer("TopKGate").start()

         if self.wg.weight.dtype != torch.float32:
             self.wg = self.wg.float()
-        input_fp32 = input.float()
+        inputs_fp32 = inputs.float()
         # input jittering
-        if self.noisy_gate_policy == 'Jitter' and self.training:
-            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
-        logits = self.wg(input_fp32)
+        if self.noisy_gate_policy == "Jitter" and self.training:
+            inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
+        logits = self.wg(inputs_fp32)

         if self.k == 1:
-            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
-                                     self.drop_tokens, self.use_rts, use_tutel)
+            gate_output = top1gating(
+                logits,
+                self.capacity_factor if self.training else self.eval_capacity_factor,
+                self.min_capacity,
+                used_token,
+                self.noisy_gate_policy if self.training else None,
+                self.drop_tokens,
+                self.use_rts,
+            )

         else:
-            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity)
+            gate_output = top2gating(
+                logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
+            )

         if self.wall_clock_breakdown:
-            timer('TopKGate').stop()
-            self.gate_time = timer('TopKGate').elapsed(reset=False)
+            timer("TopKGate").stop()
+            self.gate_time = timer("TopKGate").elapsed(reset=False)

         return gate_output

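A compact sketch of the gate's numerical path after this change (layer sizes are made up): logits are always computed in fp32 and softmaxed, and the expert choice per token is a top-k over the gate probabilities; the real forward delegates to top1gating/top2gating rather than the bare topk shown here.

import torch
import torch.nn.functional as F

model_dim, num_experts, k = 8, 4, 1
wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()

tokens = torch.randn(16, model_dim, dtype=torch.bfloat16)
logits = wg(tokens.float())                           # gating always runs in fp32
gates = F.softmax(logits, dim=1)
topk_vals, topk_idx = torch.topk(gates, k=k, dim=1)   # chosen expert(s) per token
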
@@ -416,7 +411,7 @@ class MOELayer(Base):

         gate = TopKGate(model_dim, num_experts)
         moe = MOELayer(gate, expert)
-        output = moe(input)
+        output = moe(inputs)
         l_aux = moe.l_aux

     .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
@@ -428,12 +423,7 @@ class MOELayer(Base):
             expert network
     """

-    def __init__(self,
-                 gate: Module,
-                 experts: Module,
-                 ep_group,
-                 ep_size,
-                 num_local_experts: int) -> None:
+    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
         super().__init__()
         self.gate = gate
         self.experts = experts
@@ -445,59 +435,59 @@ class MOELayer(Base):
         self.time_moe = 0.0
         self.wall_clock_breakdown = False


     def _set_ep_group(self, ep_group):
         self.ep_group = ep_group

-    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+    def forward(self, *inputs: Tensor) -> Tensor:

         if self.wall_clock_breakdown:
-            timer('moe').start()
+            timer("moe").start()

         # Implement Algorithm 2 from GShard paper.
-        d_model = input[0].shape[-1]
+        d_model = inputs[0].shape[-1]

         # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
         # Reshape into G groups so that each group can distribute tokens equally
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
-        reshaped_input = input[0].reshape(-1, d_model)
+        reshaped_inputs = inputs[0].reshape(-1, d_model)

-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
-        dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) ## TODO: heavy memory usage due to long sequence length
+        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
+        dispatched_inputs = einsum(
+            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
+        )  # TODO: heavy memory usage due to long sequence length

         if self.wall_clock_breakdown:
-            timer('falltoall').start()
+            timer("falltoall").start()

-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

         if self.wall_clock_breakdown:
-            timer('falltoall').stop()
-            self.time_falltoall = timer('falltoall').elapsed(reset=False)
+            timer("falltoall").stop()
+            self.time_falltoall = timer("falltoall").elapsed(reset=False)

         # Re-shape after all-to-all: ecm -> gecm
-        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
+        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

-        expert_output = self.experts(dispatched_input)
+        expert_output = self.experts(dispatched_inputs)

         if self.wall_clock_breakdown:
-            timer('salltoall').start()
+            timer("salltoall").start()

         expert_output = _AllToAll.apply(self.ep_group, expert_output)

         if self.wall_clock_breakdown:
-            timer('salltoall').stop()
-            self.time_salltoall = timer('salltoall').elapsed(reset=False)
+            timer("salltoall").stop()
+            self.time_salltoall = timer("salltoall").elapsed(reset=False)

         # Re-shape back: gecm -> ecm
         expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

-        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
+        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)

-        a = combined_output.reshape(input[0].shape)
+        a = combined_output.reshape(inputs[0].shape)

         if self.wall_clock_breakdown:
-            timer('moe').stop()
-            self.time_moe = timer('moe').elapsed(reset=False)
+            timer("moe").stop()
+            self.time_moe = timer("moe").elapsed(reset=False)

         return a

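Finally, an illustrative single-process shape check of the dispatch/combine einsums used in MOELayer.forward, with no all-to-all and a hand-built one-hot dispatch mask standing in for the gate output.

import torch

s, m, e, c = 8, 16, 4, 2  # tokens, model dim, experts, capacity
reshaped_inputs = torch.randn(s, m)

# Toy gating: token i goes to expert i % e, capacity slot (i // e) % c, weight 1.0.
dispatch_mask = torch.zeros(s, e, c)
for i in range(s):
    dispatch_mask[i, i % e, (i // e) % c] = 1.0
combine_weights = dispatch_mask.clone()

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_inputs)  # (e, c, m)
expert_output = dispatched * 2.0  # stand-in for the per-expert FFN
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)    # (s, m)

assert combined.shape == (s, m)
assert torch.allclose(combined, 2.0 * reshaped_inputs)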