mirror of https://github.com/hpcaitech/ColossalAI
[MOE] add FP32LinearGate for MOE in NaiveAMP context (#480)
parent 353566c198
commit d7ea63992b
@@ -8,7 +8,7 @@ from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
-from .utils import autocast_softmax
+from .utils import ForceFP32Parameter
 from typing import Callable, Optional
 from torch.distributed import ProcessGroup
 
@@ -70,7 +70,7 @@ class Top1Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = autocast_softmax(inputs, dim=-1)
+        logits = F.softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
@@ -161,7 +161,7 @@ class Top2Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
+        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
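The switch from autocast_softmax to a plain F.softmax in both routers is safe because the gate now emits fp32 logits (see FP32LinearGate below). Precision genuinely matters here: fp16 cannot distinguish nearby logits, which can tie or flip the top-k expert choice. A standalone sketch, not part of the diff, illustrating why the router wants fp32 logits:

import torch
import torch.nn.functional as F

# fp16 spacing around 10.0 is ~0.0078, so two close logits collapse
# to the same fp16 value before the softmax even runs.
logits = torch.tensor([10.0, 10.001, -5.0])
print(logits.half())          # tensor([10., 10., -5.], dtype=torch.float16)
print(F.softmax(logits, -1))  # fp32 keeps the top-1 ordering intact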
@@ -216,6 +216,23 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask
 
 
+class FP32LinearGate(nn.Linear):
+    """Gate module used in a MOE layer. Just a linear function without bias,
+    but it should always be kept in fp32.
+
+    Args:
+        d_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+
+    Attributes:
+        weight (ForceFP32Parameter): The weight of the linear gate
+    """
+
+    def __init__(self, d_model: int, num_experts: int):
+        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
+        self.weight = ForceFP32Parameter(self.weight)
+
+
 class MoeLayer(nn.Module):
     """A MoE layer that feeds its input tensor to its gate and uses the output logits
     to route all tokens; it is mainly used to exchange all tokens for every expert across
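A self-contained sketch of how FP32LinearGate interacts with a blanket fp16 cast. The device=get_current_device() argument is dropped so it runs without initializing ColossalAI, and it assumes a recent PyTorch where Module._apply invokes the conversion on the Parameter object itself:

import torch
import torch.nn as nn

class ForceFP32Parameter(nn.Parameter):
    def half(self, memory_format=None):
        return self  # no-op: refuse the fp16 cast

class FP32LinearGate(nn.Linear):
    def __init__(self, d_model: int, num_experts: int):
        super().__init__(d_model, num_experts, bias=False)
        self.weight = ForceFP32Parameter(self.weight)

class TinyBlock(nn.Module):  # toy stand-in for a model under NaiveAMP
    def __init__(self, d_model=8, num_experts=4):
        super().__init__()
        self.gate = FP32LinearGate(d_model, num_experts)
        self.ffn = nn.Linear(d_model, d_model)

block = TinyBlock().half()                       # blanket fp16 conversion
assert block.ffn.weight.dtype == torch.float16   # ordinary weights convert
assert block.gate.weight.dtype == torch.float32  # the gate weight survives

This is why a plain nn.Linear gate was not enough: the NaiveAMP context casts the whole model to fp16, and the gate would silently lose the precision the routers rely on.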
@@ -237,7 +254,7 @@ class MoeLayer(nn.Module):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
+        self.gate = FP32LinearGate(dim_model, num_experts)
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -266,7 +283,8 @@ class MoeLayer(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
-        gate_output = self.gate(tokens)
+        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
+        gate_output = self.gate(fp32_input)
         router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
         if self.use_kernel:
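The upcast in forward keeps the gate's fp32 weight and its fp16 input from colliding. A minimal standalone sketch of the dtype flow this change establishes:

import torch
import torch.nn.functional as F

d_model, num_experts = 8, 4
tokens = torch.randn(16, d_model, dtype=torch.float16)  # fp16 activations
gate_weight = torch.randn(num_experts, d_model)         # fp32 gate weight

# Mirrors the new forward logic: upcast only when needed.
fp32_input = tokens.to(torch.float32) if tokens.dtype != torch.float32 else tokens
gate_output = F.linear(fp32_input, gate_weight)
assert gate_output.dtype == torch.float32               # router sees fp32 logits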
@@ -290,7 +308,7 @@ class MoeLayer(nn.Module):
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
-            combine_weights = router_res[0]
+            combine_weights = router_res[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
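type_as(inputs) is needed in the non-kernel path because the combine weights now inherit the router's fp32 dtype while expert_output follows the model dtype, and torch.matmul rejects mixed dtypes. A small sketch of the failure the cast prevents:

import torch

combine_weights = torch.rand(6, 6)                       # fp32, from the fp32 router
expert_output = torch.randn(6, 8, dtype=torch.float16)   # fp16 expert outputs

try:
    torch.matmul(combine_weights, expert_output)         # mixed dtypes
except RuntimeError as err:
    print("mixed-dtype matmul fails:", err)

ans = torch.matmul(combine_weights.type_as(expert_output), expert_output)
assert ans.dtype == torch.float16

The remaining hunks are from the MOE utils module, where ForceFP32Parameter is defined and autocast_softmax is removed.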
@@ -1,10 +1,15 @@
 import torch
-import torch.nn.functional as F
 from colossalai.utils import get_current_device
 from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts
 
 
+class ForceFP32Parameter(torch.nn.Parameter):
+
+    def half(self, memory_format=None):
+        return self
+
+
 class NormalNoiseGenerator:
     """Generates a random noisy mask for the logits tensor.
 
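One caveat worth noting: ForceFP32Parameter guards only the half() path that Module.half() takes; an explicit .to(torch.float16) still converts. A quick sketch of that contract:

import torch

class ForceFP32Parameter(torch.nn.Parameter):
    def half(self, memory_format=None):
        return self

w = ForceFP32Parameter(torch.randn(4, 8))
assert w.half().dtype == torch.float32             # half() is a no-op
assert w.to(torch.float16).dtype == torch.float16  # only half() is guarded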
@@ -46,14 +51,6 @@ class UniformNoiseGenerator:
         return inputs * noisy
 
 
-def autocast_softmax(inputs: torch.Tensor, dim: int):
-    assert inputs.dtype in {torch.float16, torch.float32}
-    fp16_flag = (inputs.dtype == torch.float16)
-    sm_input = inputs.to(torch.float32) if fp16_flag else inputs
-    sm_output = F.softmax(sm_input, dim)
-    return sm_output
-
-
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
     mep_size = MOE_CONTEXT.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
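For reference, the deleted helper upcast fp16 inputs before the softmax; with the upcast moved into MoeLayer.forward, a plain F.softmax reproduces it exactly. A sketch of that equivalence, with the helper body reproduced from the removed lines:

import torch
import torch.nn.functional as F

def autocast_softmax(inputs: torch.Tensor, dim: int):  # removed helper
    sm_input = inputs.to(torch.float32) if inputs.dtype == torch.float16 else inputs
    return F.softmax(sm_input, dim)

x = torch.randn(16, 4, dtype=torch.float16)
assert torch.equal(autocast_softmax(x, -1), F.softmax(x.to(torch.float32), -1))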