mirror of https://github.com/hpcaitech/ColossalAI
[MOE] add FP32LinearGate for MOE in NaiveAMP context (#480)
parent
353566c198
commit
d7ea63992b
|
@ -8,7 +8,7 @@ from colossalai.core import MOE_CONTEXT
|
|||
from colossalai.utils import get_current_device
|
||||
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
|
||||
from .experts import MoeExperts
|
||||
from .utils import autocast_softmax
|
||||
from .utils import ForceFP32Parameter
|
||||
from typing import Callable, Optional
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
@ -70,7 +70,7 @@ class Top1Router(nn.Module):
|
|||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
logits = autocast_softmax(inputs, dim=-1)
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
|
@ -161,7 +161,7 @@ class Top2Router(nn.Module):
|
|||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
|
||||
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
|
@ -216,6 +216,23 @@ class Top2Router(nn.Module):
|
|||
return cb_weight, sec_mask
|
||||
|
||||
|
||||
class FP32LinearGate(nn.Linear):
|
||||
"""Gate module used in MOE layer. Just a linear function without bias.
|
||||
But it should be kept as fp32 forever.
|
||||
|
||||
Args:
|
||||
d_model (int): Hidden dimension of training model
|
||||
num_experts (int): The number experts
|
||||
|
||||
Attributes:
|
||||
weight (ForceFP32Parameter): The weight of linear gate
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, num_experts: int):
|
||||
super().__init__(d_model, num_experts, bias=False, device=get_current_device())
|
||||
self.weight = ForceFP32Parameter(self.weight)
|
||||
|
||||
|
||||
class MoeLayer(nn.Module):
|
||||
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
|
||||
to router all tokens, is mainly used to exchange all tokens for every expert across
|
||||
|
@ -237,7 +254,7 @@ class MoeLayer(nn.Module):
|
|||
super().__init__()
|
||||
self.d_model = dim_model
|
||||
self.num_experts = num_experts
|
||||
self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
|
||||
self.gate = FP32LinearGate(dim_model, num_experts)
|
||||
self.router = router
|
||||
self.experts = experts
|
||||
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
|
||||
|
@ -266,7 +283,8 @@ class MoeLayer(nn.Module):
|
|||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
tokens = inputs.reshape(-1, self.d_model)
|
||||
gate_output = self.gate(tokens)
|
||||
fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
|
||||
gate_output = self.gate(fp32_input)
|
||||
router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
|
||||
|
||||
if self.use_kernel:
|
||||
|
@ -290,7 +308,7 @@ class MoeLayer(nn.Module):
|
|||
expert_output = expert_output.reshape(-1, self.d_model)
|
||||
ans = MoeCombine.apply(expert_output, *router_res)
|
||||
else:
|
||||
combine_weights = router_res[0]
|
||||
combine_weights = router_res[0].type_as(inputs)
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
ans = torch.matmul(combine_weights, expert_output)
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.core import MOE_CONTEXT
|
||||
from .experts import FFNExperts, TPExperts
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
||||
def half(self, memory_format=None):
|
||||
return self
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
"""Generates a random noisy mask for logtis tensor.
|
||||
|
||||
|
@ -46,14 +51,6 @@ class UniformNoiseGenerator:
|
|||
return inputs * noisy
|
||||
|
||||
|
||||
def autocast_softmax(inputs: torch.Tensor, dim: int):
|
||||
assert inputs.dtype in {torch.float16, torch.float32}
|
||||
fp16_flag = (inputs.dtype == torch.float16)
|
||||
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
|
||||
sm_output = F.softmax(sm_input, dim)
|
||||
return sm_output
|
||||
|
||||
|
||||
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
|
||||
mep_size = MOE_CONTEXT.max_ep_size
|
||||
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
|
||||
|
|
Loading…
Reference in New Issue