mirror of https://github.com/hpcaitech/ColossalAI
lucasliunju
3 years ago
committed by
Frank Lee
6 changed files with 0 additions and 64 deletions
@ -1 +0,0 @@
|
||||
from .mlp_mixer import * |
@ -1,63 +0,0 @@
|
||||
# modified from https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py |
||||
from functools import partial |
||||
from colossalai.context import ParallelMode |
||||
from colossalai.registry import MODELS |
||||
from torch import nn |
||||
from colossalai import nn as col_nn |
||||
from colossalai.nn.layer.parallel_3d._utils import get_depth_from_env |
||||
from einops.layers.torch import Rearrange, Reduce |
||||
|
||||
__all__ = [ |
||||
'MLPMixer', |
||||
] |
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
    """Residual wrapper that layer-normalizes the input before applying ``fn``.

    The forward pass computes ``fn(norm(x)) + x``, where ``norm`` is a
    ``col_nn.LayerNorm3D`` built over ``dim`` with the 3D-parallel input and
    weight modes.

    Args:
        dim: size passed to ``LayerNorm3D`` as the normalized dimension.
        fn: callable (typically a module) applied to the normalized input.
        depth_3d: depth of the 3D parallel layout, forwarded to ``LayerNorm3D``.
    """

    def __init__(self, dim, fn, depth_3d):
        super().__init__()
        self.fn = fn
        input_mode = ParallelMode.PARALLEL_3D_INPUT
        weight_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.norm = col_nn.LayerNorm3D(dim, depth_3d, input_mode, weight_mode)

    def forward(self, x):
        """Return ``fn(norm(x)) + x``."""
        normalized = self.norm(x)
        transformed = self.fn(normalized)
        return transformed + x
||||
|
||||
|
||||
def FeedForward(dim, depth_3d, expansion_factor=4, dropout=0., dense=None):
    """Build a two-layer feed-forward block as an ``nn.Sequential``.

    Layout: ``dense(dim -> dim * expansion_factor)``, GELU, dropout,
    ``dense(dim * expansion_factor -> dim)``, dropout.

    Args:
        dim: input and output feature size of the block.
        depth_3d: depth of the 3D parallel layout; only used to build the
            default ``dense`` factory.
        expansion_factor: multiplier applied to ``dim`` for the hidden size.
        dropout: probability given to both ``nn.Dropout`` layers.
        dense: factory called as ``dense(in_features, out_features)``. When
            ``None``, a ``col_nn.Linear3D`` partial bound to ``depth_3d`` and
            the 3D input/weight parallel modes is used.

    Returns:
        The assembled ``nn.Sequential``.
    """
    if dense is None:
        dense = partial(
            col_nn.Linear3D,
            depth=depth_3d,
            input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
            weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT,
        )

    hidden_dim = dim * expansion_factor
    # Layers are instantiated in the same order as they appear in the block.
    layers = [
        dense(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden_dim, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)
||||
|
||||
|
||||
@MODELS.register_module
def MLPMixer(image_size, channels, patch_size, dim, depth, num_classes, expansion_factor=4, dropout=0.):
    """Build an MLP-Mixer classifier out of 3D-parallel layers.

    The returned ``nn.Sequential`` rearranges the image into flattened
    patches, projects each patch to ``dim``, applies ``depth`` mixer blocks
    (a ``Conv1d``-based feed-forward followed by a ``Linear3D``-based
    feed-forward, each wrapped in ``PreNormResidual``), normalizes,
    mean-reduces over the patch axis and projects to ``num_classes``.

    Args:
        image_size: side length of the (square) input image.
        channels: number of input image channels.
        patch_size: side length of each square patch; must divide
            ``image_size``.
        dim: feature size of each patch embedding.
        depth: number of mixer blocks.
        num_classes: output size of the final linear layer.
        expansion_factor: hidden-size multiplier of every feed-forward block.
        dropout: dropout probability of every feed-forward block.

    Returns:
        The assembled ``nn.Sequential`` model.

    Raises:
        AssertionError: if ``image_size`` is not divisible by ``patch_size``.
    """
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2
    depth_3d = get_depth_from_env()
    linear = partial(col_nn.Linear3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
                     weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
    norm_layer = partial(col_nn.LayerNorm3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
                         weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), linear

    # depth_3d must be passed explicitly: FeedForward takes it as its second
    # positional parameter and PreNormResidual as its third. Omitting it
    # shifts expansion_factor/dropout/dense into the wrong slots and leaves
    # PreNormResidual without a required argument.
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size, p2=patch_size),
        linear((patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            PreNormResidual(
                dim,
                FeedForward(num_patches, depth_3d, expansion_factor, dropout, chan_first),
                depth_3d),
            PreNormResidual(
                dim,
                FeedForward(dim, depth_3d, expansion_factor, dropout, chan_last),
                depth_3d)
        ) for _ in range(depth)],
        norm_layer(dim),
        Reduce('b n c -> b c', 'mean'),
        linear(dim, num_classes)
    )
Loading…
Reference in new issue