[MOE] add FP32LinearGate for MOE in NaiveAMP context (#480)

pull/489/head
HELSON 2022-03-22 10:50:20 +08:00 committed by GitHub
parent 353566c198
commit d7ea63992b
2 changed files with 30 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
-from .utils import autocast_softmax
+from .utils import ForceFP32Parameter
 from typing import Callable, Optional
 from torch.distributed import ProcessGroup
@@ -70,7 +70,7 @@ class Top1Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

-        logits = autocast_softmax(inputs, dim=-1)
+        logits = F.softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
@@ -161,7 +161,7 @@ class Top2Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

-        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
+        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
@@ -216,6 +216,23 @@
         return cb_weight, sec_mask


+class FP32LinearGate(nn.Linear):
+    """Gate module used in MOE layer. Just a linear function without bias,
+    but it should be kept in fp32 forever.
+
+    Args:
+        d_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+
+    Attributes:
+        weight (ForceFP32Parameter): The weight of the linear gate
+    """
+
+    def __init__(self, d_model: int, num_experts: int):
+        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
+        self.weight = ForceFP32Parameter(self.weight)
+
+
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to route all tokens, is mainly used to exchange all tokens for every expert across
@@ -237,7 +254,7 @@ class MoeLayer(nn.Module):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
+        self.gate = FP32LinearGate(dim_model, num_experts)
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -266,7 +283,8 @@
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
-        gate_output = self.gate(tokens)
+        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
+        gate_output = self.gate(fp32_input)
         router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

         if self.use_kernel:
@@ -290,7 +308,7 @@
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
-            combine_weights = router_res[0]
+            combine_weights = router_res[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
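
Taken together, these changes make the routing path dtype-safe under fp16 training: tokens are upcast to fp32 before the gate, the routers can use a plain F.softmax on the already-fp32 logits, and the combine weights are cast back to the input dtype before the final matmul. A minimal standalone sketch of that dtype flow; the plain nn.Linear stands in for FP32LinearGate, and shapes and variable names are illustrative, not from this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, num_experts = 8, 4
gate = nn.Linear(d_model, num_experts, bias=False)  # fp32 weight, like FP32LinearGate

tokens = torch.randn(16, d_model).half()            # fp16 activations, as under fp16 training

# Upcast before the gate so the input dtype matches the fp32 weight,
# mirroring the new forward() logic:
fp32_input = tokens.to(torch.float32) if tokens.dtype != torch.float32 else tokens
gate_output = gate(fp32_input)                      # logits computed in fp32
logits = F.softmax(gate_output, dim=-1)             # plain softmax suffices now

# Cast combine weights back to the activation dtype before the final matmul,
# mirroring router_res[0].type_as(inputs) above:
combine_weights = logits.type_as(tokens)
print(gate_output.dtype, combine_weights.dtype)     # torch.float32 torch.float16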

View File

@@ -1,10 +1,15 @@
 import torch
-import torch.nn.functional as F
 from colossalai.utils import get_current_device
 from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts


+class ForceFP32Parameter(torch.nn.Parameter):
+
+    def half(self, memory_format=None):
+        return self
+
+
 class NormalNoiseGenerator:
     """Generates a random noisy mask for logits tensor.
@@ -46,14 +51,6 @@ class UniformNoiseGenerator:
         return inputs * noisy


-def autocast_softmax(inputs: torch.Tensor, dim: int):
-    assert inputs.dtype in {torch.float16, torch.float32}
-    fp16_flag = (inputs.dtype == torch.float16)
-    sm_input = inputs.to(torch.float32) if fp16_flag else inputs
-    sm_output = F.softmax(sm_input, dim)
-    return sm_output
-
-
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
     mep_size = MOE_CONTEXT.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
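
Why the half() override is enough: Module.half() applies Tensor.half() to every floating-point parameter, and because ForceFP32Parameter turns that call into a no-op, the gate weight simply survives the fp16 conversion that NaiveAMP performs on the model. A quick standalone check, assuming that conversion path; the Linear below is illustrative, not from this commit:

import torch
import torch.nn as nn

class ForceFP32Parameter(torch.nn.Parameter):

    def half(self, memory_format=None):
        return self                      # refuse the fp16 cast, stay fp32

linear = nn.Linear(4, 2, bias=False)
linear.weight = ForceFP32Parameter(linear.weight)
linear.half()                            # the cast NaiveAMP applies to the whole model
print(linear.weight.dtype)               # torch.float32 -- the gate weight is untouched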