2023-07-25 07:02:29 +00:00
|
|
|
import torch
|
|
|
|
import transformers
|
|
|
|
|
|
|
|
from ..registry import ModelAttribute, model_zoo
|
|
|
|
|
|
|
|
# ===============================
|
|
|
|
# Register ViT (Vision Transformer) models
|
|
|
|
# ===============================
|
|
|
|
|
2023-08-11 07:43:23 +00:00
|
|
|
# Small ViT configuration shared by every model registered in this file;
# all remaining ViTConfig fields keep their library defaults.
config = transformers.ViTConfig(
    hidden_size=128,
    intermediate_size=256,
    num_attention_heads=4,
    num_hidden_layers=4,
)
|
2023-07-25 07:02:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
# define data gen function
|
|
|
|
def data_gen():
    """Build one random image input for the ViT models in this file.

    The channel count and spatial size are read from the module-level
    ``config`` rather than hard-coded, so the generated input stays in sync
    with the model configuration if it is ever changed.

    Returns:
        dict with a single key ``"pixel_values"`` holding a ``torch.randn``
        tensor of shape ``(1, config.num_channels, config.image_size,
        config.image_size)``.
    """
    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
    return dict(pixel_values=pixel_values)
|
|
|
|
|
|
|
|
|
|
|
|
def data_gen_for_image_classification():
    """Return a ``data_gen()`` sample extended with a ``labels`` entry.

    ``labels`` is ``torch.tensor([0])`` — a single class index of 0.
    """
    sample = data_gen()
    sample.update(labels=torch.tensor([0]))
    return sample
|
|
|
|
|
|
|
|
|
|
|
|
def data_gen_for_masked_image_modeling():
    """Return a ``data_gen()`` sample plus a random boolean patch mask.

    The mask, stored under ``"bool_masked_pos"``, has shape
    ``(1, num_patches)`` where ``num_patches`` is
    ``(config.image_size // config.patch_size) ** 2``; each entry is drawn
    uniformly from {False, True}.
    """
    # data_gen() must run first so the RNG draws happen in the same order.
    sample = data_gen()
    patches_per_side = config.image_size // config.patch_size
    mask = torch.randint(low=0, high=2, size=(1, patches_per_side**2)).bool()
    sample["bool_masked_pos"] = mask
    return sample
|
|
|
|
|
|
|
|
|
|
|
|
# Output transform: the registered models' outputs are passed through as-is.
def output_transform_fn(x):
    """Return the model output ``x`` unchanged."""
    return x


# Loss functions: reduce each model's output to a single loss value.
def loss_fn_for_vit_model(x):
    """Return the mean of ``x.pooler_output`` (for the bare ``ViTModel``)."""
    return x.pooler_output.mean()


def loss_fn_for_image_classification(x):
    """Return the mean of ``x.logits`` (for ``ViTForImageClassification``)."""
    return x.logits.mean()


def loss_fn_for_masked_image_modeling(x):
    """Return ``x.loss``, the loss reported by ``ViTForMaskedImageModeling``."""
    return x.loss
|
|
|
|
|
|
|
|
# Register three ViT variants, all built from the shared `config`:
#   - transformers.ViTModel                   -> "transformers_vit"
#   - transformers.ViTForMaskedImageModeling  -> "transformers_vit_for_masked_image_modeling"
#   - transformers.ViTForImageClassification  -> "transformers_vit_for_image_classification"
model_zoo.register(
    name="transformers_vit",
    model_fn=lambda: transformers.ViTModel(config),
    data_gen_fn=data_gen,
    loss_fn=loss_fn_for_vit_model,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(
    name="transformers_vit_for_masked_image_modeling",
    model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
    data_gen_fn=data_gen_for_masked_image_modeling,
    loss_fn=loss_fn_for_masked_image_modeling,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(
    name="transformers_vit_for_image_classification",
    model_fn=lambda: transformers.ViTForImageClassification(config),
    data_gen_fn=data_gen_for_image_classification,
    loss_fn=loss_fn_for_image_classification,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
|