[MOE] add FP32LinearGate for MOE in NaiveAMP context (#480)

pull/489/head
HELSON 2022-03-22 10:50:20 +08:00 committed by GitHub
parent 353566c198
commit d7ea63992b
2 changed files with 30 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ from colossalai.core import MOE_CONTEXT
from colossalai.utils import get_current_device
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts
- from .utils import autocast_softmax
+ from .utils import ForceFP32Parameter
from typing import Callable, Optional
from torch.distributed import ProcessGroup
@@ -70,7 +70,7 @@ class Top1Router(nn.Module):
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
- logits = autocast_softmax(inputs, dim=-1)
+ logits = F.softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
@@ -161,7 +161,7 @@ class Top2Router(nn.Module):
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
- logits = autocast_softmax(inputs, dim=-1)  # logits: [s, e]
+ logits = F.softmax(inputs, dim=-1)  # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
@@ -216,6 +216,23 @@ class Top2Router(nn.Module):
return cb_weight, sec_mask
+ class FP32LinearGate(nn.Linear):
+     """Gate module used in the MoE layer. It is just a linear function without bias,
+     but its weight should always be kept in fp32.
+ 
+     Args:
+         d_model (int): Hidden dimension of the training model
+         num_experts (int): The number of experts
+ 
+     Attributes:
+         weight (ForceFP32Parameter): The weight of the linear gate
+     """
+ 
+     def __init__(self, d_model: int, num_experts: int):
+         super().__init__(d_model, num_experts, bias=False, device=get_current_device())
+         self.weight = ForceFP32Parameter(self.weight)
class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -237,7 +254,7 @@ class MoeLayer(nn.Module):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
- self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
+ self.gate = FP32LinearGate(dim_model, num_experts)
self.router = router
self.experts = experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -266,7 +283,8 @@ class MoeLayer(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
- gate_output = self.gate(tokens)
+ fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
+ gate_output = self.gate(fp32_input)
router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
if self.use_kernel:
@@ -290,7 +308,7 @@ class MoeLayer(nn.Module):
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *router_res)
else:
- combine_weights = router_res[0]
+ combine_weights = router_res[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)

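For context, the new forward path only changes the dtype handling around the gate: activations are upcast to fp32 before the gate, the routing softmax runs in fp32, and the combine weights are cast back to the activation dtype before the final matmul. Below is a minimal standalone sketch of that pattern in plain PyTorch; ToyFP32Gate, the tensor names, and the toy sizes are illustrative and not Colossal-AI's API.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyFP32Gate(nn.Linear):
    """Illustrative stand-in for FP32LinearGate: a bias-free linear gate
    whose weight is meant to stay in fp32 even when the surrounding model
    has been converted to fp16."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__(d_model, num_experts, bias=False)

d_model, num_experts, num_tokens = 8, 4, 16
gate = ToyFP32Gate(d_model, num_experts)          # weight stays fp32

# Pretend the rest of the model runs in fp16, as under NaiveAMP.
tokens = torch.randn(num_tokens, d_model, dtype=torch.float16)

# Upcast activations so the gate and the softmax compute in fp32.
fp32_input = tokens.to(torch.float32) if tokens.dtype != torch.float32 else tokens
gate_output = gate(fp32_input)                    # fp32 logits
logits = F.softmax(gate_output, dim=-1)

# Cast the combine weights back to the activation dtype before combining
# with the fp16 expert outputs, mirroring router_res[0].type_as(inputs).
combine_weights = logits.type_as(tokens)

assert gate.weight.dtype == torch.float32
assert combine_weights.dtype == torch.float16

Without the cast back, the fp32 combine weights would not match the dtype of the fp16 expert output in the final torch.matmul.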
View File

@@ -1,10 +1,15 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
from colossalai.core import MOE_CONTEXT
from .experts import FFNExperts, TPExperts
+ class ForceFP32Parameter(torch.nn.Parameter):
+ 
+     def half(self, memory_format=None):
+         return self
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
@@ -46,14 +51,6 @@ class UniformNoiseGenerator:
return inputs * noisy
- def autocast_softmax(inputs: torch.Tensor, dim: int):
-     assert inputs.dtype in {torch.float16, torch.float32}
-     fp16_flag = (inputs.dtype == torch.float16)
-     sm_input = inputs.to(torch.float32) if fp16_flag else inputs
-     sm_output = F.softmax(sm_input, dim)
-     return sm_output
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
mep_size = MOE_CONTEXT.max_ep_size
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
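The reason a parameter-level override is enough: NaiveAMP casts the wrapped model to fp16, which boils down to a module-wide half() call, and nn.Module.half() applies param.half() to every floating-point parameter. Because ForceFP32Parameter overrides half() to return itself, the gate weight survives that conversion in fp32 while every other parameter is halved. A minimal sketch of that behaviour, reusing the three-line class from the diff above; TinyModel and the sizes are illustrative, not part of the library.

import torch
import torch.nn as nn

class ForceFP32Parameter(torch.nn.Parameter):
    # Same behaviour as the class added above: nn.Module.half() calls
    # param.half() on each floating-point parameter, and returning `self`
    # leaves this one untouched.
    def half(self, memory_format=None):
        return self

class TinyModel(nn.Module):
    """Illustrative model: an MoE-style gate plus an ordinary layer."""

    def __init__(self, d_model: int = 8, num_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.gate.weight = ForceFP32Parameter(self.gate.weight)  # pin the gate to fp32
        self.ffn = nn.Linear(d_model, d_model)                   # normal, convertible weight

model = TinyModel().half()        # roughly what the NaiveAMP wrapper does to the model

print(model.gate.weight.dtype)    # torch.float32 -- the gate keeps full precision
print(model.ffn.weight.dtype)     # torch.float16 -- the rest of the model is halved

Because only the parameter class is overridden, the AMP wrapper itself needs no changes; FP32LinearGate can be dropped into any model that is later converted by NaiveAMP.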