From d7ea63992bd2b9dd7de0ba03633d5345b17ac549 Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Tue, 22 Mar 2022 10:50:20 +0800
Subject: [PATCH] [MOE] add FP32LinearGate for MOE in NaiveAMP context (#480)

---
 colossalai/nn/layer/moe/layers.py | 30 ++++++++++++++++++++++++------
 colossalai/nn/layer/moe/utils.py  | 15 ++++++---------
 2 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index 39b23abed..7903b6286 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -8,7 +8,7 @@ from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
-from .utils import autocast_softmax
+from .utils import ForceFP32Parameter
 from typing import Callable, Optional
 from torch.distributed import ProcessGroup
 
@@ -70,7 +70,7 @@ class Top1Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = autocast_softmax(inputs, dim=-1)
+        logits = F.softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
@@ -161,7 +161,7 @@ class Top2Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
+        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
@@ -216,6 +216,23 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask
 
 
+class FP32LinearGate(nn.Linear):
+    """Gate module used in MOE layer. Just a linear function without bias.
+    But it should be kept as fp32 forever.
+
+    Args:
+        d_model (int): Hidden dimension of training model
+        num_experts (int): The number of experts
+
+    Attributes:
+        weight (ForceFP32Parameter): The weight of linear gate
+    """
+
+    def __init__(self, d_model: int, num_experts: int):
+        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
+        self.weight = ForceFP32Parameter(self.weight)
+
+
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -237,7 +254,7 @@ class MoeLayer(nn.Module):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
+        self.gate = FP32LinearGate(dim_model, num_experts)
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -266,7 +283,8 @@ class MoeLayer(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
-        gate_output = self.gate(tokens)
+        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
+        gate_output = self.gate(fp32_input)
         router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
         if self.use_kernel:
@@ -290,7 +308,7 @@ class MoeLayer(nn.Module):
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
-            combine_weights = router_res[0]
+            combine_weights = router_res[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 98f54cde7..ad9c99621 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -1,10 +1,15 @@
 import torch
-import torch.nn.functional as F
 from colossalai.utils import get_current_device
 from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts
 
 
+class ForceFP32Parameter(torch.nn.Parameter):
+
+    def half(self, memory_format=None):
+        return self
+
+
 class NormalNoiseGenerator:
     """Generates a random noisy mask for logtis tensor.
 
@@ -46,14 +51,6 @@ class UniformNoiseGenerator:
         return inputs * noisy
 
 
-def autocast_softmax(inputs: torch.Tensor, dim: int):
-    assert inputs.dtype in {torch.float16, torch.float32}
-    fp16_flag = (inputs.dtype == torch.float16)
-    sm_input = inputs.to(torch.float32) if fp16_flag else inputs
-    sm_output = F.softmax(sm_input, dim)
-    return sm_output
-
-
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
     mep_size = MOE_CONTEXT.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
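
Note on the mechanism: NaiveAMP converts the whole model with a blanket
model.half() call, and on recent PyTorch nn.Module.half() casts each
parameter by calling its half() method, which ForceFP32Parameter overrides
as a no-op, so the gate weight survives the cast in fp32. Below is a
minimal, self-contained sketch of that behavior, not part of the diff
above; the colossalai-specific device=get_current_device() argument is
omitted and the shapes are illustrative only.

    import torch
    import torch.nn as nn


    class ForceFP32Parameter(torch.nn.Parameter):

        def half(self, memory_format=None):
            # Refuse the fp16 cast so this parameter stays in fp32.
            return self


    class FP32LinearGate(nn.Linear):

        def __init__(self, d_model: int, num_experts: int):
            super().__init__(d_model, num_experts, bias=False)
            self.weight = ForceFP32Parameter(self.weight)


    gate = FP32LinearGate(d_model=16, num_experts=4)
    gate.half()                                  # NaiveAMP-style blanket fp16 cast
    assert gate.weight.dtype == torch.float32    # gate weight is untouched

    # MoeLayer.forward up-casts fp16 tokens before the gate, and the router's
    # combine weights are cast back afterwards with .type_as(inputs):
    tokens = torch.randn(8, 16, dtype=torch.float16)
    gate_output = gate(tokens.to(torch.float32))    # fp32 logits for the router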