ColossalAI/model_zoo/vit/parallel_2d/vit.py

219 lines
7.0 KiB
Python

from colossalai.context import ParallelMode, seed
from colossalai import nn as clsl_nn
from colossalai.registry import MODELS
from torch import nn
import torch
__all__ = [
'VisionTransformer2D',
'vit_tiny_2d_patch4_32',
'vit_tiny_2d_patch16_224',
'vit_tiny_2d_patch16_384',
'vit_small_2d_patch16_224',
'vit_small_2d_patch16_384',
'vit_small_2d_patch32_224',
'vit_small_2d_patch32_384',
'vit_base_2d_patch16_224',
'vit_base_2d_patch16_384',
'vit_base_2d_patch32_224',
'vit_base_2d_patch32_384',
'vit_large_2d_patch16_224',
'vit_large_2d_patch16_384',
'vit_large_2d_patch32_224',
'vit_large_2d_patch32_384',
]
class ViTBlock2D(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: int = 4,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: str = 'gelu'):
super().__init__()
self.norm1 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
self.attn = clsl_nn.ViTSelfAttention2D(dim, num_heads, attn_drop, drop)
self.drop_path = clsl_nn.VanillaViTDropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
self.mlp = clsl_nn.ViTMLP2D(dim, mlp_ratio, act_layer, drop)
def forward(self, x):
y = self.attn(self.norm1(x))
with seed(ParallelMode.TENSOR):
x = x + self.drop_path(y)
y = self.mlp(self.norm2(x))
with seed(ParallelMode.TENSOR):
x = x + self.drop_path(y)
return x
@MODELS.register_module
class VisionTransformer2D(nn.Module):
def __init__(self,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: int = 4,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
act_layer: str = 'gelu'):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = clsl_nn.ViTPatchEmbedding2D(
img_size, patch_size, embed_dim, in_chans
)
self.splitter = clsl_nn.ViTInputSplitter2D()
self.token_fuser = clsl_nn.ViTTokenFuser2D(
img_size, patch_size, embed_dim, drop_rate
)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(*[
ViTBlock2D(embed_dim, num_heads, mlp_ratio, drop_rate,
attn_drop_rate, dpr[i], act_layer)
for i in range(depth)
])
self.norm = clsl_nn.LayerNorm2D(embed_dim, eps=1e-6)
self.head = clsl_nn.ViTHead2D(self.num_features, num_classes) if num_classes > 0 \
else nn.Identity()
self.init_weights()
def init_weights(self):
pass
def forward(self, x):
x = self.patch_embed(x)
x = self.splitter(x)
x = self.token_fuser(x)
x = self.blocks(x)
x = self.norm(x)
x = self.head(x)
return x
def _create_vit_model(**model_kwargs):
model = VisionTransformer2D(**model_kwargs)
return model
@MODELS.register_module
def vit_tiny_2d_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
depth=6, num_heads=8, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=192,
depth=12, num_heads=3, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=192,
depth=12, num_heads=3, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)