From dbdc9a778389128e51601806d92c756657dad67c Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Wed, 16 Mar 2022 16:47:44 +0800
Subject: [PATCH] added multiplicative jitter and an evaluation capacity
 factor for MoE (#434)

---
 colossalai/nn/layer/moe/layers.py | 79 ++++++++++++++++++++++---------
 colossalai/nn/layer/moe/utils.py  | 21 ++++++++
 model_zoo/moe/models.py           | 19 ++++++--
 3 files changed, 92 insertions(+), 27 deletions(-)

diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index 49a9645bc..f98e0764e 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -11,6 +11,7 @@ from colossalai.utils import get_current_device
 from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
+from typing import Callable
 
 
 class Top1Router(nn.Module):
@@ -18,21 +19,35 @@ class Top1Router(nn.Module):
     for routing usage. More details can be found in Google's Switch Transformer paper.
 
-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor of routing during training
+    :param capacity_factor_eval: Capacity factor of routing during evaluation
     :param min_capacity: The minimum capacity of each expert
+    :param select_policy: The policy for token selection
     :param noisy_func: Noise function applied to the logits
+    :param drop_tks: Whether to drop tokens during evaluation
 
-    :type capacity_factor: float
-    :type min_capacity: int
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
+    :type select_policy: str, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """
 
-    def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
         self.min_capacity = min_capacity
         self.select_policy = select_policy
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks
 
         assert select_policy in {"first", "random"}
         if select_policy == "random":
@@ -44,7 +59,8 @@ class Top1Router(nn.Module):
         self,
         logits_shape,
     ):
-        capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
@@ -53,15 +69,13 @@ def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
 
         if self.noisy_func is not None and self.training:
-            inputs_noisy = self.noisy_func(inputs)
-        else:
-            inputs_noisy = inputs
+            inputs = self.noisy_func(inputs)
 
         logits = autocast_softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
-        top1_idx = torch.argmax(inputs_noisy, dim=-1)
+        top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
 
         if self.training:
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass
 
-        if not self.training:
-            ranks = moe_cumsum(mask)
-        elif self.select_policy == "random":
+        if self.select_policy == "random":
             rand_mask = mask * self.uniform(mask.shape)
             _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
             mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
@@ -106,21 +120,40 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More details can be found in the ViT-MoE paper.
 
-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor of routing during training
+    :param capacity_factor_eval: Capacity factor of routing during evaluation
+    :param min_capacity: The minimum capacity of each expert
     :param noisy_func: Noise function applied to the logits
+    :param drop_tks: Whether to drop tokens during evaluation
 
-    :type capacity_factor: float
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """
 
-    def __init__(self, capacity_factor: float, noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks
 
-    def get_capacity(self, logits_shape):
-        capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+    def get_capacity(
+        self,
+        logits_shape,
+    ):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
         assert capacity > 0
         return capacity
@@ -143,12 +176,14 @@ class Top2Router(nn.Module):
         if self.training:
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce) / 2.0
+            l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass
 
         rank1 = moe_cumsum(mask1)    # rank1: [s, e]
         rank2 = moe_cumsum(mask2)
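Both routers now share the same capacity rule, switching on self.training between the train and eval factors. The following is a minimal standalone sketch of that rule; the free function expert_capacity and the token counts are illustrative and not part of the patch:

    import math

    def expert_capacity(num_tokens: int, num_experts: int,
                        capacity_factor: float, min_capacity: int = 4) -> int:
        # Mirrors the patched get_capacity: scale average tokens-per-expert
        # by the capacity factor, round up to an even number, then clamp
        # to the minimum capacity.
        capacity = math.floor(capacity_factor * num_tokens / num_experts)
        capacity += capacity % 2
        return max(capacity, min_capacity)

    print(expert_capacity(1024, 16, capacity_factor=1.25))    # training   -> 80
    print(expert_capacity(1024, 16, capacity_factor=2.0))     # evaluation -> 128

The larger evaluation factor leaves more slack per expert, so fewer tokens are dropped at inference time; passing drop_tks=False goes further and raises the capacity to the largest expert load actually observed across the MoE group.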
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 4fa090662..37c57f396 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -25,6 +25,27 @@ class NormalNoiseGenerator:
         return inputs + noisy
 
 
+class UniformNoiseGenerator:
+    """Generates random multiplicative noise for a logits tensor.
+    Copied from Mesh TensorFlow: multiplies values by a random number
+    between 1-epsilon and 1+epsilon.
+    Makes models more resilient to rounding errors introduced by bfloat16.
+    This seems particularly important for logits.
+
+    :param eps: Epsilon of the noise interval
+    :type eps: float
+    """
+
+    def __init__(self, eps: float):
+        self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
+                                                           high=torch.tensor(1.0 + eps,
+                                                                             device=get_current_device())).rsample
+
+    def __call__(self, inputs: torch.Tensor):
+        noisy = self.uniform(inputs.shape)
+        return inputs * noisy
+
+
 def autocast_softmax(inputs: torch.Tensor, dim: int):
     assert inputs.dtype in {torch.float16, torch.float32}
     fp16_flag = (inputs.dtype == torch.float16)
diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py
index 277e377c7..40d288c4b 100644
--- a/model_zoo/moe/models.py
+++ b/model_zoo/moe/models.py
@@ -84,7 +84,9 @@ class Widenet(nn.Module):
 
     def __init__(self,
                  num_experts: int,
-                 capacity_factor: float,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 drop_tks: bool = True,
                  img_size: int = 224,
                  patch_size: int = 16,
                  in_chans: int = 3,
@@ -109,7 +111,10 @@ class Widenet(nn.Module):
                           d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop,
                           drop_rate=drop_rate))
 
         noisy_func = NormalNoiseGenerator(num_experts)
-        shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
+        shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
+                                   capacity_factor_eval=capacity_factor_eval,
+                                   noisy_func=noisy_func,
+                                   drop_tks=drop_tks)
         shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
 
         # stochastic depth decay rule
@@ -142,7 +147,9 @@ class ViTMoE(nn.Module):
 
     def __init__(self,
                  num_experts: int,
-                 capacity_factor: float,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 drop_tks: bool = True,
                  img_size: int = 224,
                  patch_size: int = 16,
                  in_chans: int = 3,
@@ -164,8 +171,10 @@ class ViTMoE(nn.Module):
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
 
         noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor, noisy_func=noisy_func)
-
+        router = Top2Router(capacity_factor_train=capacity_factor_train,
+                            capacity_factor_eval=capacity_factor_eval,
+                            noisy_func=noisy_func,
+                            drop_tks=drop_tks)
         assert depth % 2 == 0
 
         # stochastic depth decay rule
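The multiplicative jitter added in utils.py is self-contained, so it can be sketched without the rest of the library. Below is a CPU-only sketch mirroring UniformNoiseGenerator; the class name UniformNoise and the eps value are illustrative, and the real implementation additionally places its tensors on the device returned by get_current_device():

    import torch

    class UniformNoise:
        # Scale each logit by a factor drawn from U(1 - eps, 1 + eps).
        def __init__(self, eps: float = 1e-2):
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(1.0 - eps), high=torch.tensor(1.0 + eps)).rsample

        def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
            return inputs * self.uniform(inputs.shape)

    logits = torch.randn(8, 4)    # [tokens, experts]
    jittered = UniformNoise()(logits)

Because the noise is multiplicative and centered on 1, well-separated logits keep their argmax, while near-ties introduced by low-precision rounding are broken randomly, which is the bfloat16 resilience the Mesh TensorFlow comment describes.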