mirror of https://github.com/hpcaitech/ColossalAI
Added ViLT-MLM model
parent 3f70a2b12f
commit 7fde1442e9
@@ -0,0 +1 @@
from .vilt import ViLT
@@ -0,0 +1,227 @@
from typing import Callable

import torch
import torch.nn.functional as F
from torch import dtype, nn
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings, BertPredictionHeadTransform

from colossalai import nn as col_nn
from colossalai.nn.layer.colossalai_layer import LayerNorm
from colossalai.registry import MODELS
from model_zoo.vit.vit import ViTBlock, ViTEmbedding


@MODELS.register_module
class ViLT(nn.Module):
    """Vision-Language Transformer (ViLT).

    Text tokens (BERT embeddings) and image patches (ViT embeddings) are
    concatenated and encoded by a single transformer stack; the model
    supports masked language modeling (MLM).
    """

    def __init__(
        self,
        max_text_len: int,
        num_layers: int,
        vocab_size: int,
        hidden_size: int,
        img_size: int = 384,
        patch_size: int = 16,
        in_chans: int = 3,
        depth: int = 12,
        num_heads: int = 12,
        dim: int = 768,
        mlp_ratio: int = 4,
        attention_dropout: float = 0.,
        dropout: float = 0.1,
        drop_path: float = 0.,
        layernorm_epsilon: float = 1e-6,
        activation: Callable = nn.functional.gelu,
        dtype: dtype = None,
        bias: bool = True,
        checkpoint: bool = False,
        init_method: str = 'torch',
    ):
        super().__init__()
        max_sequence_length = max_text_len
        self.vocab_size = vocab_size
        self.num_layers = num_layers

        # The text side reuses HuggingFace BERT components (embeddings and the
        # prediction-head transform used by MLMHead).
        bert_config = BertConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=hidden_size * mlp_ratio,
            max_position_embeddings=max_sequence_length,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=attention_dropout,
        )

        self.pooler = Pooler(hidden_size)
        self.token_type_embeddings = nn.Embedding(2, hidden_size)
        self.token_type_embeddings.apply(init_weights)
        self.text_embedding = BertEmbeddings(bert_config)
        self.vis_embedding = ViTEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embedding_dim=dim,
            dropout=dropout,
            dtype=dtype,
            init_method=init_method)

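        # A single stack of ViT blocks encodes the concatenated text and image
        # tokens (see infer()); drop-path rates increase linearly with depth
        # (stochastic depth).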
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            ViTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attention_dropout=attention_dropout,
                dropout=dropout,
                drop_path=dpr[i],
                activation=activation,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
                init_method=init_method,
            ) for i in range(depth)
        ]
        norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        # The MLM head is always built here; a pipeline-parallel variant would
        # only build it on the last stage.
        self.mlm_score = MLMHead(bert_config)
        self.mlm_score.apply(init_weights)

        self.layer_norm = LayerNorm(hidden_size)

        layers = []
        layers.extend(blocks)
        layers.append(norm)
        self.layers = nn.Sequential(*layers)

    def infer(self, x, image_token_type_idx=1):
        # The batch dict is expected to provide an image tensor under
        # "image" (or "image_{idx}") and MLM-masked text under
        # "text_ids_mlm" / "text_labels_mlm".
        do_mlm = "_mlm"
        if f"image_{image_token_type_idx - 1}" in x:
            imgkey = f"image_{image_token_type_idx - 1}"
        else:
            imgkey = "image"
        img = x[imgkey]
        text_ids = x[f"text_ids{do_mlm}"]
        text_labels = x[f"text_labels{do_mlm}"]

        image_embeds = self.vis_embedding(img)
        text_embeds = self.text_embedding(text_ids)

        # Concatenate text and image tokens, run them through the shared
        # encoder, then split the output back into per-modality features.
        co_embeds = torch.cat([text_embeds, image_embeds], dim=1)
        x = self.layers(co_embeds)
        text_feats, image_feats = (
            x[:, :text_embeds.shape[1]],
            x[:, text_embeds.shape[1]:],
        )
        cls_feats = self.pooler(x)

        ret = {
            "text_feats": text_feats,
            "image_feats": image_feats,
            "cls_feats": cls_feats,
            "raw_cls_feats": x[:, 0],
            "text_labels": text_labels,
            "text_ids": text_ids,
        }
        return ret

    def forward(self, x):
        ret = dict()
        ret.update(self.compute_mlm(x))
        return ret

    def compute_mlm(self, batch):
        infer = self.infer(batch)
        mlm_logits = self.mlm_score(infer["text_feats"])
        mlm_labels = infer["text_labels"]

        # Positions labelled -100 (the HuggingFace convention for unmasked
        # tokens) are ignored by the loss.
        mlm_loss = F.cross_entropy(
            mlm_logits.view(-1, self.vocab_size),
            mlm_labels.view(-1),
            ignore_index=-100,
        )

        ret = {
            "mlm_loss": mlm_loss,
            "mlm_logits": mlm_logits,
            "mlm_labels": mlm_labels,
            "mlm_ids": infer["text_ids"],
        }

        return ret

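# Heads and helpers used by ViLT. Pooler transforms the first ([CLS]) token;
# MLMHead predicts vocabulary logits for masked language modeling. ITMHead
# (binary image-text matching) and MPPHead (masked patch prediction over
# 256 * 3 discretized channel bins) correspond to ViLT's other pretraining
# objectives; they are defined here but not wired into forward().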

class Pooler(nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ITMHead(nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        x = self.fc(x)
        return x


class MLMHead(nn.Module):

    def __init__(self, config, weight=None):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Optionally tie the decoder to an existing embedding weight.
        if weight is not None:
            self.decoder.weight = weight

    def forward(self, x):
        x = self.transform(x)
        x = self.decoder(x) + self.bias
        return x


class MPPHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, 256 * 3)

    def forward(self, x):
        x = self.transform(x)
        x = self.decoder(x)
        return x


def get_current_device():
    """Returns the index of the currently selected CUDA device, or 'cpu'."""
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    else:
        return 'cpu'

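# e.g. tensors can be created directly on the active device:
#   x = torch.zeros(8, 8, device=get_current_device())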

def init_weights(module):
    # BERT-style initialization: N(0, 0.02) weights, unit LayerNorm scale,
    # zeroed biases.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
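For reference, a minimal sketch of how the new module might be exercised. The import path, hyperparameter values, and batch shapes below are illustrative assumptions (not part of this commit), and an initialized ColossalAI context is assumed for the ViT layers.

import torch

from model_zoo.vilt import ViLT  # path assumed; adjust to the actual package layout

# hidden_size must match `dim` (default 768) so text and image embeddings
# can be concatenated along the sequence dimension.
model = ViLT(max_text_len=40, num_layers=12, vocab_size=30522, hidden_size=768)

batch = {
    "image": torch.randn(2, 3, 384, 384),               # pixel values
    "text_ids_mlm": torch.randint(0, 30522, (2, 40)),   # masked input ids
    "text_labels_mlm": torch.full((2, 40), -100),       # -100 = not masked / ignored
}
batch["text_labels_mlm"][:, 5] = batch["text_ids_mlm"][:, 5]  # pretend position 5 was masked

out = model(batch)
print(out["mlm_loss"].item(), out["mlm_logits"].shape)  # scalar loss, (2, 40, 30522)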