added Multiply Jitter and capacity factor eval for MOE (#434)

pull/420/head^2
HELSON 2022-03-16 16:47:44 +08:00 committed by GitHub
parent b03b3ae99c
commit dbdc9a7783
3 changed files with 92 additions and 27 deletions

File 1 of 3

@@ -11,6 +11,7 @@ from colossalai.utils import get_current_device
 from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
+from typing import Callable


 class Top1Router(nn.Module):
@@ -18,21 +19,35 @@ class Top1Router(nn.Module):
     for routing usage. More detailed information can be found in the paper about Switch Transformer
     of Google.

-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum capacity of each expert
+    :param select_policy: The policy for token selection
     :param noisy_func: Noise function applied to the logits
+    :param drop_tks: Whether to drop tokens during evaluation

-    :type capacity_factor: float
-    :type min_capacity: int
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
+    :type select_policy: str, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """

-    def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
         self.min_capacity = min_capacity
         self.select_policy = select_policy
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks

         assert select_policy in {"first", "random"}
         if select_policy == "random":
@@ -44,7 +59,8 @@ class Top1Router(nn.Module):
         self,
         logits_shape,
     ):
-        capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
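The capacity rule above now switches factors with the module's train/eval mode. A minimal standalone sketch of the same arithmetic (function and parameter names are illustrative, not from the patch):

```python
import math

def expert_capacity(num_tokens: int, num_experts: int, training: bool,
                    factor_train: float = 1.25, factor_eval: float = 2.0,
                    min_capacity: int = 4) -> int:
    # logits_shape[-2] is the token count, logits_shape[-1] the expert count
    factor = factor_train if training else factor_eval
    capacity = math.floor(factor * num_tokens / num_experts)
    capacity += capacity % 2              # round an odd capacity up to even
    return max(capacity, min_capacity)    # never fall below min_capacity

# 1024 tokens over 8 experts: 160 slots per expert while training,
# 256 slots per expert during evaluation.
assert expert_capacity(1024, 8, training=True) == 160
assert expert_capacity(1024, 8, training=False) == 256
```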
@@ -53,15 +69,13 @@ class Top1Router(nn.Module):
     def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
         if self.noisy_func is not None and self.training:
-            inputs_noisy = self.noisy_func(inputs)
-        else:
-            inputs_noisy = inputs
+            inputs = self.noisy_func(inputs)

         logits = autocast_softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

-        top1_idx = torch.argmax(inputs_noisy, dim=-1)
+        top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

         if self.training:
@@ -69,14 +83,14 @@ class Top1Router(nn.Module):
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass

-        if not self.training:
-            ranks = moe_cumsum(mask)
-        elif self.select_policy == "random":
+        if self.select_policy == "random":
             rand_mask = mask * self.uniform(mask.shape)
             _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
             mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
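With `drop_tks=False`, evaluation widens the capacity to the busiest expert's actual load instead of dropping overflow tokens; the patched code additionally all-reduces that maximum across the MOE_MODEL group. A single-process sketch of the rule (the function name is illustrative):

```python
import torch

def eval_capacity(mask: torch.Tensor, base_capacity: int, drop_tks: bool) -> int:
    # mask: [tokens, experts] one-hot routing decisions from the argmax
    if drop_tks:
        return base_capacity                      # overflow tokens are dropped
    max_num = torch.max(torch.sum(mask, dim=0))   # load on the busiest expert
    return int(max_num.item())                    # grow capacity so nothing drops

mask = torch.tensor([[1, 0], [1, 0], [1, 0], [0, 1]])
print(eval_capacity(mask, base_capacity=2, drop_tks=False))  # 3
```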
@@ -106,21 +120,40 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More detailed information can be found in the paper about ViT-MoE.

-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
+    :param min_capacity: The minimum capacity of each expert
     :param noisy_func: Noise function applied to the logits
+    :param drop_tks: Whether to drop tokens during evaluation

-    :type capacity_factor: float
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """

-    def __init__(self, capacity_factor: float, noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks

-    def get_capacity(self, logits_shape):
-        capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+    def get_capacity(
+        self,
+        logits_shape,
+    ):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
         assert capacity > 0
         return capacity
@@ -143,12 +176,14 @@ class Top2Router(nn.Module):
         if self.training:
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce) / 2.0
+            l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass

         rank1 = moe_cumsum(mask1)    # rank1: [s, e]
         rank2 = moe_cumsum(mask2)
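The new `# div 2 to normalize it to 1` comment can be sanity-checked: with perfectly uniform top-2 routing over `e` experts, `me` is `1/e` per expert and `ce` is `2/e` (every token picks two experts), so `e * sum(me * ce) = 2` and the division brings the loss to exactly 1. A quick numeric check:

```python
import torch

num_experts, tokens = 8, 1024
logits = torch.full((tokens, num_experts), 1.0 / num_experts)  # uniform router probs
cmask = torch.zeros(tokens, num_experts)
for t in range(tokens):                       # spread each token's two picks
    cmask[t, (2 * t) % num_experts] = 1       # evenly over all experts
    cmask[t, (2 * t + 1) % num_experts] = 1

me = torch.mean(logits, dim=0)
ce = torch.mean(cmask, dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
print(l_aux)  # tensor(1.)
```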

File 2 of 3

@@ -25,6 +25,27 @@ class NormalNoiseGenerator:
         return inputs + noisy


+class UniformNoiseGenerator:
+    """Generates a random noisy mask for a logits tensor.
+    Copied from Mesh TensorFlow:
+    Multiply values by a random number between 1 - epsilon and 1 + epsilon.
+    Makes models more resilient to rounding errors introduced by bfloat16.
+    This seems particularly important for logits.
+
+    :param eps: Epsilon in the generator
+    :type eps: float
+    """
+
+    def __init__(self, eps: float):
+        self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
+                                                           high=torch.tensor(1.0 + eps,
+                                                                             device=get_current_device())).rsample
+
+    def __call__(self, inputs: torch.Tensor):
+        noisy = self.uniform(inputs.shape)
+        return inputs * noisy
+
+
 def autocast_softmax(inputs: torch.Tensor, dim: int):
     assert inputs.dtype in {torch.float16, torch.float32}
     fp16_flag = (inputs.dtype == torch.float16)
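A minimal re-statement of the multiplicative jitter above, with a usage example. The class name and the `eps` value are illustrative, and the original's `get_current_device()` placement is dropped to keep the sketch self-contained:

```python
import torch

class MultiplicativeJitter:
    """Scale each input by a factor drawn uniformly from [1 - eps, 1 + eps]."""

    def __init__(self, eps: float = 1e-2):    # eps value is an assumption
        self.uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - eps), high=torch.tensor(1.0 + eps)).rsample

    def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
        return inputs * self.uniform(inputs.shape)

logits = torch.randn(16, 8)                   # [tokens, experts] router logits
noisy = MultiplicativeJitter()(logits)
assert noisy.shape == logits.shape
```

Unlike NormalNoiseGenerator, which adds noise, this generator multiplies, so the perturbation scales with the magnitude of each logit.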

File 3 of 3

@@ -84,7 +84,9 @@ class Widenet(nn.Module):
     def __init__(self,
                  num_experts: int,
-                 capacity_factor: float,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 drop_tks: bool = True,
                  img_size: int = 224,
                  patch_size: int = 16,
                  in_chans: int = 3,
@@ -109,7 +111,10 @@ class Widenet(nn.Module):
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))

         noisy_func = NormalNoiseGenerator(num_experts)
-        shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
+        shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
+                                   capacity_factor_eval=capacity_factor_eval,
+                                   noisy_func=noisy_func,
+                                   drop_tks=drop_tks)
         shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)

         # stochastic depth decay rule
@@ -142,7 +147,9 @@ class ViTMoE(nn.Module):
     def __init__(self,
                  num_experts: int,
-                 capacity_factor: float,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 drop_tks: bool = True,
                  img_size: int = 224,
                  patch_size: int = 16,
                  in_chans: int = 3,
@@ -164,8 +171,10 @@ class ViTMoE(nn.Module):
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

         noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor, noisy_func=noisy_func)
+        router = Top2Router(capacity_factor_train=capacity_factor_train,
+                            capacity_factor_eval=capacity_factor_eval,
+                            noisy_func=noisy_func,
+                            drop_tks=drop_tks)
         assert depth % 2 == 0

         # stochastic depth decay rule
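A hedged construction sketch showing the new keywords flowing from the model into its Top2Router (imports are omitted because the module path is not visible in the diff, and all other constructor arguments are assumed to keep the defaults shown above):

```python
# Hypothetical usage of the patched ViTMoE signature.
model = ViTMoE(num_experts=8,                # value chosen for illustration
               capacity_factor_train=1.25,   # used while model.training is True
               capacity_factor_eval=2.0,     # used after model.eval()
               drop_tks=True)                # drop overflow tokens at eval time
model.eval()   # flips the routers to capacity_factor_eval via self.training
```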