|
|
@ -4,11 +4,12 @@ import torch.nn as nn |
|
|
|
from colossalai.context import ParallelMode |
|
|
|
from colossalai.context import ParallelMode |
|
|
|
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \ |
|
|
|
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \ |
|
|
|
WrappedDropout as Dropout, WrappedDropPath as DropPath |
|
|
|
WrappedDropout as Dropout, WrappedDropPath as DropPath |
|
|
|
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator |
|
|
|
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule |
|
|
|
from .util import moe_sa_args, moe_mlp_args |
|
|
|
from .util import moe_sa_args, moe_mlp_args |
|
|
|
from ..helper import TransformerLayer |
|
|
|
from ..helper import TransformerLayer |
|
|
|
from colossalai.core import MOE_CONTEXT |
|
|
|
from colossalai.core import MOE_CONTEXT |
|
|
|
from colossalai.utils import get_current_device |
|
|
|
from colossalai.utils import get_current_device |
|
|
|
|
|
|
|
from typing import List |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VanillaSelfAttention(nn.Module): |
|
|
|
class VanillaSelfAttention(nn.Module): |
|
|
@ -146,7 +147,8 @@ class Widenet(nn.Module): |
|
|
|
class ViTMoE(nn.Module): |
|
|
|
class ViTMoE(nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
def __init__(self, |
|
|
|
num_experts: int, |
|
|
|
num_experts: int or List[int], |
|
|
|
|
|
|
|
use_residual: bool = False, |
|
|
|
capacity_factor_train: float = 1.25, |
|
|
|
capacity_factor_train: float = 1.25, |
|
|
|
capacity_factor_eval: float = 2.0, |
|
|
|
capacity_factor_eval: float = 2.0, |
|
|
|
drop_tks: bool = True, |
|
|
|
drop_tks: bool = True, |
|
|
@ -164,29 +166,45 @@ class ViTMoE(nn.Module): |
|
|
|
drop_path: float = 0.): |
|
|
|
drop_path: float = 0.): |
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert depth % 2 == 0, "The number of layers should be even right now" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(num_experts, list): |
|
|
|
|
|
|
|
assert len(num_experts) == depth // 2, \ |
|
|
|
|
|
|
|
"The length of num_experts should equal to the number of MOE layers" |
|
|
|
|
|
|
|
num_experts_list = num_experts |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
num_experts_list = [num_experts] * (depth // 2) |
|
|
|
|
|
|
|
|
|
|
|
embedding = VanillaPatchEmbedding(img_size=img_size, |
|
|
|
embedding = VanillaPatchEmbedding(img_size=img_size, |
|
|
|
patch_size=patch_size, |
|
|
|
patch_size=patch_size, |
|
|
|
in_chans=in_chans, |
|
|
|
in_chans=in_chans, |
|
|
|
embed_size=d_model) |
|
|
|
embed_size=d_model) |
|
|
|
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) |
|
|
|
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) |
|
|
|
|
|
|
|
|
|
|
|
noisy_func = NormalNoiseGenerator(num_experts) |
|
|
|
|
|
|
|
router = Top2Router(capacity_factor_train=capacity_factor_train, |
|
|
|
|
|
|
|
capacity_factor_eval=capacity_factor_eval, |
|
|
|
|
|
|
|
noisy_func=noisy_func, |
|
|
|
|
|
|
|
drop_tks=drop_tks) |
|
|
|
|
|
|
|
assert depth % 2 == 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# stochastic depth decay rule |
|
|
|
# stochastic depth decay rule |
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] |
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] |
|
|
|
blocks = [] |
|
|
|
blocks = [] |
|
|
|
for i in range(depth): |
|
|
|
for i in range(depth): |
|
|
|
sa = VanillaSelfAttention(**moe_sa_args( |
|
|
|
sa = VanillaSelfAttention(**moe_sa_args( |
|
|
|
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate)) |
|
|
|
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate)) |
|
|
|
ffn = VanillaFFN(**moe_mlp_args( |
|
|
|
|
|
|
|
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \ |
|
|
|
if i % 2 == 0: |
|
|
|
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router, |
|
|
|
ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) |
|
|
|
experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)) |
|
|
|
else: |
|
|
|
|
|
|
|
num_experts = num_experts_list[i // 2] |
|
|
|
|
|
|
|
experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate) |
|
|
|
|
|
|
|
ffn = MoeModule(dim_model=d_model, |
|
|
|
|
|
|
|
num_experts=num_experts, |
|
|
|
|
|
|
|
top_k=1 if use_residual else 2, |
|
|
|
|
|
|
|
capacity_factor_train=capacity_factor_train, |
|
|
|
|
|
|
|
capacity_factor_eval=capacity_factor_eval, |
|
|
|
|
|
|
|
noisy_policy='Jitter' if use_residual else 'Gaussian', |
|
|
|
|
|
|
|
drop_tks=drop_tks, |
|
|
|
|
|
|
|
use_residual=use_residual, |
|
|
|
|
|
|
|
expert_instance=experts, |
|
|
|
|
|
|
|
expert_cls=VanillaFFN, |
|
|
|
|
|
|
|
**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) |
|
|
|
|
|
|
|
|
|
|
|
layer = TransformerLayer(att=sa, |
|
|
|
layer = TransformerLayer(att=sa, |
|
|
|
ffn=ffn, |
|
|
|
ffn=ffn, |
|
|
|
norm1=nn.LayerNorm(d_model, eps=1e-6), |
|
|
|
norm1=nn.LayerNorm(d_model, eps=1e-6), |
|
|
|