mirror of https://github.com/hpcaitech/ColossalAI
fix format (#374)
parent
526a318032
commit
ce886a9062
|
@ -1 +0,0 @@
|
||||||
from .mlp_mixer import *
|
|
|
@ -1,63 +0,0 @@
|
||||||
# modified from https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py
|
|
||||||
from functools import partial
|
|
||||||
from colossalai.context import ParallelMode
|
|
||||||
from colossalai.registry import MODELS
|
|
||||||
from torch import nn
|
|
||||||
from colossalai import nn as col_nn
|
|
||||||
from colossalai.nn.layer.parallel_3d._utils import get_depth_from_env
|
|
||||||
from einops.layers.torch import Rearrange, Reduce
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'MLPMixer',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class PreNormResidual(nn.Module):
|
|
||||||
def __init__(self, dim, fn, depth_3d):
|
|
||||||
super().__init__()
|
|
||||||
self.fn = fn
|
|
||||||
self.norm = col_nn.LayerNorm3D(
|
|
||||||
dim, depth_3d, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.fn(self.norm(x)) + x
|
|
||||||
|
|
||||||
|
|
||||||
def FeedForward(dim, depth_3d, expansion_factor=4, dropout=0., dense=None):
|
|
||||||
if dense is None:
|
|
||||||
dense = partial(col_nn.Linear3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
|
||||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
|
||||||
return nn.Sequential(
|
|
||||||
dense(dim, dim * expansion_factor),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
dense(dim * expansion_factor, dim),
|
|
||||||
nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module
|
|
||||||
def MLPMixer(image_size, channels, patch_size, dim, depth, num_classes, expansion_factor=4, dropout=0.):
|
|
||||||
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
|
|
||||||
num_patches = (image_size // patch_size) ** 2
|
|
||||||
depth_3d = get_depth_from_env()
|
|
||||||
linear = partial(col_nn.Linear3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
|
||||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
|
||||||
norm_layer = partial(col_nn.LayerNorm3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
|
||||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
|
||||||
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), linear
|
|
||||||
|
|
||||||
return nn.Sequential(
|
|
||||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
|
|
||||||
p1=patch_size, p2=patch_size),
|
|
||||||
linear((patch_size ** 2) * channels, dim),
|
|
||||||
*[nn.Sequential(
|
|
||||||
PreNormResidual(dim, FeedForward(
|
|
||||||
num_patches, expansion_factor, dropout, chan_first)),
|
|
||||||
PreNormResidual(dim, FeedForward(
|
|
||||||
dim, expansion_factor, dropout, chan_last))
|
|
||||||
) for _ in range(depth)],
|
|
||||||
norm_layer(dim),
|
|
||||||
Reduce('b n c -> b c', 'mean'),
|
|
||||||
linear(dim, num_classes)
|
|
||||||
)
|
|
Loading…
Reference in New Issue