added Multiply Jitter and capacity factor eval for MOE (#434)

pull/420/head^2
HELSON 2022-03-16 16:47:44 +08:00 committed by GitHub
parent b03b3ae99c
commit dbdc9a7783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 27 deletions

View File

@ -11,6 +11,7 @@ from colossalai.utils import get_current_device
from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts
from .utils import autocast_softmax
from typing import Callable
class Top1Router(nn.Module):
@ -18,21 +19,35 @@ class Top1Router(nn.Module):
for routing usage. More deailted function can be found in the paper about Switch Transformer
of Google.
:param capacity_factor: Capacity factor in routing
:param capacity_factor_train: Capacity factor in routing of training
:param capacity_factor_eval: Capacity factor in routing of evaluation
:param min_capacity: The minimum number of the capacity of each expert
:param select_policy: The policy about tokens selection
:param noisy_func: Noisy function used in logits
:param drop_tks: Whether drops tokens in evaluation
:type capacity_factor: float
:type min_capacity: int
:type capacity_factor_train: float, optional
:type capacity_factor_eval: float, optional
:type min_capacity: int, optional
:type select_policy: str, optional
:type noisy_func: Callable, optional
:type drop_tks: bool, optional
"""
def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__()
self.capacity_factor = capacity_factor
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.select_policy = select_policy
self.noisy_func = noisy_func
self.drop_tks = drop_tks
assert select_policy in {"first", "random"}
if select_policy == "random":
@ -44,7 +59,8 @@ class Top1Router(nn.Module):
self,
logits_shape,
):
capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
@ -53,15 +69,13 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
inputs = self.noisy_func(inputs)
logits = autocast_softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(inputs_noisy, dim=-1)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
if self.training:
@ -69,14 +83,14 @@ class Top1Router(nn.Module):
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
moe_env.add_loss(l_aux)
else:
elif not self.drop_tks:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
else:
pass
if not self.training:
ranks = moe_cumsum(mask)
elif self.select_policy == "random":
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
@ -106,21 +120,40 @@ class Top2Router(nn.Module):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about ViT-MoE.
:param capacity_factor: Capacity factor in routing
:param capacity_factor_train: Capacity factor in routing of training
:param capacity_factor_eval: Capacity factor in routing of evaluation
:param min_capacity: The minimum number of the capacity of each expert
:param noisy_func: Noisy function used in logits
:param drop_tks: Whether drops tokens in evaluation
:type capacity_factor: float
:type capacity_factor_train: float, optional
:type capacity_factor_eval: float, optional
:type min_capacity: int, optional
:type noisy_func: Callable, optional
:type drop_tks: bool, optional
"""
def __init__(self, capacity_factor: float, noisy_func=None):
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__()
self.capacity_factor = capacity_factor
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
def get_capacity(self, logits_shape):
capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
def get_capacity(
self,
logits_shape,
):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
@ -143,12 +176,14 @@ class Top2Router(nn.Module):
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
moe_env.add_loss(l_aux)
else:
elif not self.drop_tks:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
else:
pass
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)

View File

@ -25,6 +25,27 @@ class NormalNoiseGenerator:
return inputs + noisy
class UniformNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
copied from mesh tensorflow:
Multiply values by a random number between 1-epsilon and 1+epsilon.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
:param eps: Epsilon in generator
:type eps: float
"""
def __init__(self, eps: float):
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps,
device=get_current_device())).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy
def autocast_softmax(inputs: torch.Tensor, dim: int):
assert inputs.dtype in {torch.float16, torch.float32}
fp16_flag = (inputs.dtype == torch.float16)

View File

@ -84,7 +84,9 @@ class Widenet(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
@ -109,7 +111,10 @@ class Widenet(nn.Module):
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
@ -142,7 +147,9 @@ class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
@ -164,8 +171,10 @@ class ViTMoE(nn.Module):
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func)
router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
assert depth % 2 == 0
# stochastic depth decay rule