ColossalAI/model_zoo/vit/vit.py

416 lines
14 KiB
Python

import math
from typing import Callable
import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS
from torch import dtype, nn
# Public API of this module: the generic model class plus the preset factories.
__all__ = [
    'VisionTransformer',
    # presets built for 32x32 inputs
    'vit_lite_depth7_patch4_32',
    'vit_tiny_patch4_32',
    # "tiny" presets (dim=192)
    'vit_tiny_patch16_224',
    'vit_tiny_patch16_384',
    # "small" presets (dim=384)
    'vit_small_patch16_224',
    'vit_small_patch16_384',
    'vit_small_patch32_224',
    'vit_small_patch32_384',
    # "base" presets (dim=768)
    'vit_base_patch16_224',
    'vit_base_patch16_384',
    'vit_base_patch32_224',
    'vit_base_patch32_384',
    # "large" presets (dim=1024)
    'vit_large_patch16_224',
    'vit_large_patch16_384',
    'vit_large_patch32_224',
    'vit_large_patch32_384',
]
# Initializer presets, indexed first by ``init_method`` and then by model part
# ('embed', 'transformer', 'head'). Each leaf mapping is unpacked as keyword
# arguments into the col_nn layer built for that part.
# NOTE(review): the 'torch' / 'jax' names presumably mirror the default
# initialisation of those frameworks' ViT implementations - confirm.
_init_rules = {
    'torch': {
        'embed': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
            'position_embed_initializer': col_nn.init.zeros_(),
        },
        'transformer': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
        },
        'head': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
        },
    },
    'jax': {
        'embed': {
            'weight_initializer': col_nn.init.lecun_normal_(),
            'bias_initializer': col_nn.init.zeros_(),
            'position_embed_initializer': col_nn.init.trunc_normal_(std=.02),
        },
        'transformer': {
            'weight_initializer': col_nn.init.xavier_uniform_(),
            'bias_initializer': col_nn.init.normal_(std=1e-6),
        },
        'head': {
            'weight_initializer': col_nn.init.zeros_(),
            'bias_initializer': col_nn.init.zeros_(),
        },
    },
}
@LAYERS.register_module
class ViTEmbedding(nn.Module):
    """Patch embedding followed by dropout.

    Args:
        img_size: Forwarded positionally to ``col_nn.PatchEmbedding``.
        patch_size: Forwarded positionally to ``col_nn.PatchEmbedding``.
        in_chans: Forwarded positionally to ``col_nn.PatchEmbedding``.
        embedding_dim: Forwarded positionally to ``col_nn.PatchEmbedding``.
        dropout: Argument for the ``col_nn.Dropout`` applied after embedding.
        dtype: Passed to ``col_nn.PatchEmbedding`` as ``dtype``.
        flatten: Passed to ``col_nn.PatchEmbedding`` as ``flatten``.
        init_method: Key into ``_init_rules``; its ``'embed'`` entry supplies
            the initializer keyword arguments.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embedding_dim: int,
                 dropout: float,
                 dtype: dtype = None,
                 flatten: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        embed_init = _init_rules[init_method]['embed']
        self.patch_embed = col_nn.PatchEmbedding(
            img_size,
            patch_size,
            in_chans,
            embedding_dim,
            dtype=dtype,
            flatten=flatten,
            **embed_init,
        )
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        """Embed ``x`` into patches, then apply dropout to the result."""
        return self.dropout(self.patch_embed(x))
@LAYERS.register_module
class ViTSelfAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: Input and output feature width of the layer.
        num_heads: Used to derive the per-head width as ``dim // num_heads``.
        attention_dropout: Argument for the dropout applied to the attention
            probabilities.
        dropout: Argument for the dropout applied after the output projection.
        bias: Whether the fused QKV projection has a bias. The output
            projection is always created with ``bias=True``.
        dtype: Passed to both ``col_nn.Linear`` layers.
        init_method: Key into ``_init_rules``; its ``'transformer'`` entry
            supplies the initializer keyword arguments.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None,
                 init_method: str = 'torch'):
        super().__init__()
        self.attention_head_size = dim // num_heads
        linear_init = _init_rules[init_method]['transformer']
        self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias, **linear_init)
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **linear_init)
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        The four-axis permute below requires the QKV projection of ``x`` to be
        3-D; presumably ``x`` is (batch, sequence, dim) - TODO confirm.
        """
        head_dim = self.attention_head_size
        fused = self.query_key_value(x)

        # Sizes come from the projected tensor rather than the constructor
        # arguments, so they follow whatever width the linear layer returns.
        hidden = fused.shape[-1] // 3
        heads = hidden // head_dim

        # Split the last axis into (heads, 3 * head_dim), move the head axis
        # ahead of the sequence axis, then cut each head into q, k and v.
        fused = fused.view(fused.shape[:-1] + (heads, 3 * head_dim))
        fused = fused.permute((0, 2, 1, 3))
        query, key, value = torch.chunk(fused, 3, dim=-1)

        # Scaled dot-product attention.
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / math.sqrt(head_dim)
        weights = self.attention_dropout(self.softmax(scores))
        context = torch.matmul(weights, value)

        # Bring the head axis back next to the feature axis and merge them.
        context = context.transpose(1, 2)
        context = context.reshape(context.size()[:-2] + (hidden, ))

        return self.dropout(self.dense(context))
@LAYERS.register_module
class ViTMLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, dropout, linear, dropout.

    Args:
        dim: Input and output feature width.
        mlp_ratio: The hidden width is ``mlp_ratio * dim``.
        activation: Callable applied between the two linear layers.
        dropout: Argument for both dropout layers.
        dtype: Passed to both ``col_nn.Linear`` layers.
        bias: Passed to both ``col_nn.Linear`` layers.
        init_method: Key into ``_init_rules``; its ``'transformer'`` entry
            supplies the initializer keyword arguments.
    """

    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        hidden_dim = mlp_ratio * dim
        linear_init = _init_rules[init_method]['transformer']
        self.dense_1 = col_nn.Linear(dim, hidden_dim, dtype=dtype, bias=bias, **linear_init)
        self.activation = activation
        self.dropout_1 = col_nn.Dropout(dropout)
        self.dense_2 = col_nn.Linear(hidden_dim, dim, dtype=dtype, bias=bias, **linear_init)
        self.dropout_2 = col_nn.Dropout(dropout)

    def forward(self, x):
        """Run ``x`` through expand, activate, dropout, project, dropout."""
        hidden = self.dropout_1(self.activation(self.dense_1(x)))
        return self.dropout_2(self.dense_2(hidden))
@LAYERS.register_module
class ViTHead(nn.Module):
    """Classification head with an optional intermediate linear layer.

    Args:
        dim: Feature width of the incoming tensor.
        num_classes: Output width of the ``col_nn.Classifier``.
        representation_size: When truthy, a ``col_nn.Linear`` from ``dim`` to
            this width is inserted before the classifier; otherwise the
            classifier consumes ``dim`` features directly.
        dtype: Passed to the linear layer and the classifier.
        bias: Passed to the linear layer and the classifier.
        init_method: Key into ``_init_rules``; its ``'head'`` entry supplies
            the initializer keyword arguments.
    """

    def __init__(self,
                 dim: int,
                 num_classes: int,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        head_init = _init_rules[init_method]['head']
        if not representation_size:
            self.representation = None
            classifier_in = dim
        else:
            self.representation = col_nn.Linear(dim, representation_size, bias=bias, dtype=dtype, **head_init)
            classifier_in = representation_size
        self.dense = col_nn.Classifier(classifier_in, num_classes, dtype=dtype, bias=bias, **head_init)

    def forward(self, x):
        """Classify from the element at index 0 of the second axis of ``x``.

        NOTE(review): index 0 is presumably the class token - confirm against
        the embedding layer.
        """
        first_token = x[:, 0]
        if self.representation is not None:
            first_token = self.representation(first_token)
        return self.dense(first_token)
@LAYERS.register_module
class ViTBlock(CheckpointModule):
    """Pre-norm transformer block: attention and MLP, each in a residual branch.

    Args:
        dim: Feature width used by the norms, the attention and the MLP.
        num_heads: Forwarded to ``ViTSelfAttention``.
        mlp_ratio: Forwarded to ``ViTMLP``.
        activation: Forwarded to ``ViTMLP``.
        attention_dropout: Forwarded to ``ViTSelfAttention``.
        dropout: Forwarded to both ``ViTSelfAttention`` and ``ViTMLP``.
        drop_path: When greater than 0, a ``col_nn.DropPath`` built from it is
            applied to each residual branch; otherwise ``nn.Identity`` is used.
        layernorm_epsilon: ``eps`` for both layer norms.
        dtype: Forwarded to every sub-layer.
        bias: Forwarded to ``ViTSelfAttention`` and ``ViTMLP``.
        checkpoint: Handed to the ``CheckpointModule`` constructor
            (presumably toggles activation checkpointing - confirm).
        init_method: Forwarded to ``ViTSelfAttention`` and ``ViTMLP``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__(checkpoint)

        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = ViTSelfAttention(
            dim=dim,
            num_heads=num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
            bias=bias,
            dtype=dtype,
            init_method=init_method,
        )

        # One drop-path module is shared by both residual branches.
        if drop_path > 0.:
            self.drop_path = col_nn.DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.mlp = ViTMLP(
            dim=dim,
            mlp_ratio=mlp_ratio,
            activation=activation,
            dropout=dropout,
            dtype=dtype,
            bias=bias,
            init_method=init_method,
        )

    def _forward(self, x):
        """Add the attention branch, then the MLP branch, to ``x``."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
@MODELS.register_module
class VisionTransformer(nn.Module):
    """Vision Transformer assembled as one ``nn.Sequential`` pipeline.

    The pipeline is: ``ViTEmbedding``, ``depth`` copies of ``ViTBlock``, a
    ``col_nn.LayerNorm``, and a ``ViTHead``.

    Args:
        img_size: Forwarded to ``ViTEmbedding``.
        patch_size: Forwarded to ``ViTEmbedding``.
        in_chans: Forwarded to ``ViTEmbedding``.
        num_classes: Forwarded to ``ViTHead``.
        depth: Number of ``ViTBlock`` layers.
        num_heads: Forwarded to every ``ViTBlock``.
        dim: Feature width shared by all components.
        mlp_ratio: Forwarded to every ``ViTBlock``.
        attention_dropout: Forwarded to every ``ViTBlock``.
        dropout: Forwarded to ``ViTEmbedding`` and every ``ViTBlock``.
        drop_path: Upper end of the per-block drop-path schedule, which rises
            linearly from 0 at the first block to this value at the last.
        layernorm_epsilon: ``eps`` of the final layer norm. It is not passed
            to the blocks, which keep their own default.
        activation: Forwarded to every ``ViTBlock``.
        representation_size: Forwarded to ``ViTHead``.
        dtype: Forwarded to every component.
        bias: Forwarded to every ``ViTBlock`` and to ``ViTHead``.
        checkpoint: Forwarded to every ``ViTBlock``.
        init_method: Forwarded to every component that takes it.
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 num_heads: int = 12,
                 dim: int = 768,
                 mlp_ratio: int = 4,
                 attention_dropout: float = 0.,
                 dropout: float = 0.1,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 activation: Callable = nn.functional.gelu,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__()

        # Components are created in pipeline order and collected in one list.
        stages = [
            ViTEmbedding(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embedding_dim=dim,
                dropout=dropout,
                dtype=dtype,
                init_method=init_method,
            )
        ]

        # stochastic depth decay rule: one rate per block, evenly spaced in
        # [0, drop_path]
        drop_rates = torch.linspace(0, drop_path, depth).tolist()
        for rate in drop_rates:
            stages.append(
                ViTBlock(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    attention_dropout=attention_dropout,
                    dropout=dropout,
                    drop_path=rate,
                    activation=activation,
                    dtype=dtype,
                    bias=bias,
                    checkpoint=checkpoint,
                    init_method=init_method,
                ))

        stages.append(col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype))
        stages.append(
            ViTHead(
                dim=dim,
                num_classes=num_classes,
                representation_size=representation_size,
                dtype=dtype,
                bias=bias,
                init_method=init_method,
            ))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the whole pipeline and return the head output."""
        return self.layers(x)
def _create_vit_model(**model_kwargs):
    """Instantiate ``VisionTransformer`` with the given keyword arguments."""
    return VisionTransformer(**model_kwargs)
@MODELS.register_module
def vit_lite_depth7_patch4_32(**kwargs):
    """Build the ViT-Lite preset (depth 7, patch size 4, image size 32).

    Preset values: ``img_size=32, patch_size=4, dim=256, depth=7,
    num_heads=4, mlp_ratio=2, num_classes=10``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset (for example ``num_classes``) replaces
            the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=32, patch_size=4, dim=256, depth=7, num_heads=4, mlp_ratio=2, num_classes=10)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch4_32(**kwargs):
    """Build the ViT-Tiny preset for patch size 4 and image size 32.

    Preset values: ``img_size=32, patch_size=4, dim=512, depth=6,
    num_heads=8, mlp_ratio=1, num_classes=10``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset (for example ``num_classes``) replaces
            the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=32, patch_size=4, dim=512, depth=6, num_heads=8, mlp_ratio=1, num_classes=10)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch16_224(**kwargs):
    """Build the ViT-Tiny preset for patch size 16 and image size 224.

    Preset values: ``img_size=224, patch_size=16, dim=192, depth=12,
    num_heads=3, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch16_384(**kwargs):
    """Build the ViT-Tiny preset for patch size 16 and image size 384.

    Preset values: ``img_size=384, patch_size=16, dim=192, depth=12,
    num_heads=3, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch16_224(**kwargs):
    """Build the ViT-Small preset for patch size 16 and image size 224.

    Preset values: ``img_size=224, patch_size=16, dim=384, depth=12,
    num_heads=6, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch16_384(**kwargs):
    """Build the ViT-Small preset for patch size 16 and image size 384.

    Preset values: ``img_size=384, patch_size=16, dim=384, depth=12,
    num_heads=6, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch32_224(**kwargs):
    """Build the ViT-Small preset for patch size 32 and image size 224.

    Preset values: ``img_size=224, patch_size=32, dim=384, depth=12,
    num_heads=6, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch32_384(**kwargs):
    """Build the ViT-Small preset for patch size 32 and image size 384.

    Preset values: ``img_size=384, patch_size=32, dim=384, depth=12,
    num_heads=6, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch16_224(**kwargs):
    """Build the ViT-Base preset for patch size 16 and image size 224.

    Preset values: ``img_size=224, patch_size=16, dim=768, depth=12,
    num_heads=12, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch16_384(**kwargs):
    """Build the ViT-Base preset for patch size 16 and image size 384.

    Preset values: ``img_size=384, patch_size=16, dim=768, depth=12,
    num_heads=12, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch32_224(**kwargs):
    """Build the ViT-Base preset for patch size 32 and image size 224.

    Preset values: ``img_size=224, patch_size=32, dim=768, depth=12,
    num_heads=12, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch32_384(**kwargs):
    """Build the ViT-Base preset for patch size 32 and image size 384.

    Preset values: ``img_size=384, patch_size=32, dim=768, depth=12,
    num_heads=12, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch16_224(**kwargs):
    """Build the ViT-Large preset for patch size 16 and image size 224.

    Preset values: ``img_size=224, patch_size=16, dim=1024, depth=24,
    num_heads=16, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch16_384(**kwargs):
    """Build the ViT-Large preset for patch size 16 and image size 384.

    Preset values: ``img_size=384, patch_size=16, dim=1024, depth=24,
    num_heads=16, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch32_224(**kwargs):
    """Build the ViT-Large preset for patch size 32 and image size 224.

    Preset values: ``img_size=224, patch_size=32, dim=1024, depth=24,
    num_heads=16, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=224, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch32_384(**kwargs):
    """Build the ViT-Large preset for patch size 32 and image size 384.

    Preset values: ``img_size=384, patch_size=32, dim=1024, depth=24,
    num_heads=16, mlp_ratio=4``.

    Args:
        **kwargs: Additional ``VisionTransformer`` constructor arguments. A
            key that is also a preset replaces the preset value.

    Returns:
        The model returned by ``_create_vit_model``.
    """
    presets = dict(img_size=384, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    # Merge instead of ``dict(preset=..., **kwargs)``: the latter raises
    # TypeError on a repeated key, which made presets impossible to override.
    model_kwargs = {**presets, **kwargs}
    return _create_vit_model(**model_kwargs)