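# Vision Transformer (ViT) model zoo for ColossalAI. This file defines the
# parallel-aware building blocks (patch embedding, self-attention, MLP,
# classification head, transformer block) and factory functions for common
# ViT configurations. (Descriptive header added for orientation; the code
# below is the source of truth.)
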
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS
from torch import dtype, nn

__all__ = [
    'VisionTransformer',
    'vit_lite_depth7_patch4_32',
    'vit_tiny_patch4_32',
    'vit_tiny_patch16_224',
    'vit_tiny_patch16_384',
    'vit_small_patch16_224',
    'vit_small_patch16_384',
    'vit_small_patch32_224',
    'vit_small_patch32_384',
    'vit_base_patch16_224',
    'vit_base_patch16_384',
    'vit_base_patch32_224',
    'vit_base_patch32_384',
    'vit_large_patch16_224',
    'vit_large_patch16_384',
    'vit_large_patch32_224',
    'vit_large_patch32_384',
]

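# Two parameter-initialization schemes, selected via `init_method`:
# 'torch' mirrors PyTorch's default nn.Linear initialization
# (Kaiming-uniform weights with a=sqrt(5)), while 'jax' follows the recipe
# of the original JAX/Flax ViT implementation (Lecun-normal patch embedding,
# Xavier-uniform transformer weights, zero-initialized head).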
_init_rules = dict(
    torch=dict(
        embed=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
            position_embed_initializer=col_nn.init.zeros_(),
        ),
        transformer=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
        ),
        head=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
        ),
    ),
    jax=dict(
        embed=dict(
            weight_initializer=col_nn.init.lecun_normal_(),
            bias_initializer=col_nn.init.zeros_(),
            position_embed_initializer=col_nn.init.trunc_normal_(std=.02),
        ),
        transformer=dict(
            weight_initializer=col_nn.init.xavier_uniform_(),
            bias_initializer=col_nn.init.normal_(std=1e-6),
        ),
        head=dict(
            weight_initializer=col_nn.init.zeros_(),
            bias_initializer=col_nn.init.zeros_(),
        ),
    ),
)


@LAYERS.register_module
class ViTEmbedding(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embedding_dim: int,
                 dropout: float,
                 dtype: dtype = None,
                 flatten: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        self.patch_embed = col_nn.PatchEmbedding(img_size,
                                                 patch_size,
                                                 in_chans,
                                                 embedding_dim,
                                                 dtype=dtype,
                                                 flatten=flatten,
                                                 **_init_rules[init_method]['embed'])
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class ViTSelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None,
                 init_method: str = 'torch'):
        super().__init__()
        self.attention_head_size = dim // num_heads
        self.query_key_value = col_nn.Linear(dim,
                                             3 * dim,
                                             dtype=dtype,
                                             bias=bias,
                                             **_init_rules[init_method]['transformer'])
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer'])
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Fused QKV projection: (B, S, D) -> (B, S, 3D).
        qkv = self.query_key_value(x)
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        # Split into heads: (B, S, 3D) -> (B, H, S, 3 * head_size).
        new_qkv_shape = qkv.shape[:-1] + \
            (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Scaled dot-product attention.
        x = torch.matmul(q, k.transpose(-1, -2))
        x = x / math.sqrt(self.attention_head_size)
        x = self.softmax(x)
        x = self.attention_dropout(x)

        # Weighted sum of values, then merge heads: (B, H, S, head_size) -> (B, S, D).
        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)

        x = self.dense(x)
        x = self.dropout(x)

        return x


@LAYERS.register_module
class ViTMLP(nn.Module):
    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        self.dense_1 = col_nn.Linear(dim,
                                     mlp_ratio * dim,
                                     dtype=dtype,
                                     bias=bias,
                                     **_init_rules[init_method]['transformer'])
        self.activation = activation
        self.dropout_1 = col_nn.Dropout(dropout)
        self.dense_2 = col_nn.Linear(mlp_ratio * dim,
                                     dim,
                                     dtype=dtype,
                                     bias=bias,
                                     **_init_rules[init_method]['transformer'])
        self.dropout_2 = col_nn.Dropout(dropout)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dropout_1(x)
        x = self.dense_2(x)
        x = self.dropout_2(x)
        return x


@LAYERS.register_module
class ViTHead(nn.Module):
    def __init__(self,
                 dim: int,
                 num_classes: int,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        if representation_size:
            self.representation = col_nn.Linear(dim,
                                                representation_size,
                                                bias=bias,
                                                dtype=dtype,
                                                **_init_rules[init_method]['head'])
        else:
            self.representation = None
            representation_size = dim

        self.dense = col_nn.Classifier(representation_size,
                                       num_classes,
                                       dtype=dtype,
                                       bias=bias,
                                       **_init_rules[init_method]['head'])

    def forward(self, x):
        # Classify on the class token (position 0) only.
        x = x[:, 0]
        if self.representation is not None:
            x = self.representation(x)
        x = self.dense(x)
        return x


@LAYERS.register_module
class ViTBlock(CheckpointModule):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__(checkpoint)
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = ViTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype,
                                     init_method=init_method)
        self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.mlp = ViTMLP(dim=dim,
                          mlp_ratio=mlp_ratio,
                          activation=activation,
                          dropout=dropout,
                          dtype=dtype,
                          bias=bias,
                          init_method=init_method)

    def _forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


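# The full model is assembled as a single nn.Sequential pipeline:
# patch embedding -> `depth` pre-norm transformer blocks (with a linearly
# increasing DropPath rate per block) -> final LayerNorm -> classification head.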
@MODELS.register_module
class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 num_heads: int = 12,
                 dim: int = 768,
                 mlp_ratio: int = 4,
                 attention_dropout: float = 0.,
                 dropout: float = 0.1,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 activation: Callable = nn.functional.gelu,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__()

        embed = ViTEmbedding(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=in_chans,
                             embedding_dim=dim,
                             dropout=dropout,
                             dtype=dtype,
                             init_method=init_method)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            ViTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attention_dropout=attention_dropout,
                dropout=dropout,
                drop_path=dpr[i],
                activation=activation,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
                init_method=init_method,
            ) for i in range(depth)
        ]

        norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        head = ViTHead(dim=dim,
                       num_classes=num_classes,
                       representation_size=representation_size,
                       dtype=dtype,
                       bias=bias,
                       init_method=init_method)

        self.layers = nn.Sequential(
            embed,
            *blocks,
            norm,
            head,
        )

    def forward(self, x):
        x = self.layers(x)
        return x


def _create_vit_model(**model_kwargs):
    model = VisionTransformer(**model_kwargs)
    return model


@MODELS.register_module
def vit_lite_depth7_patch4_32(**kwargs):
    model_kwargs = dict(img_size=32, patch_size=4, dim=256, depth=7, num_heads=4, mlp_ratio=2, num_classes=10, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch4_32(**kwargs):
    model_kwargs = dict(img_size=32, patch_size=4, dim=512, depth=6, num_heads=8, mlp_ratio=1, num_classes=10, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)
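

# Minimal usage sketch (assumption: ColossalAI has been initialized first,
# e.g. via one of its launch utilities, since col_nn layers may be partitioned
# across devices in a parallel context; the input shape is illustrative):
#
#     model = vit_tiny_patch16_224()
#     logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 1000)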