2023-03-15 03:26:10 +00:00
|
|
|
import torch
|
|
|
|
import transformers
|
|
|
|
|
|
|
|
from ..registry import ModelAttribute, model_zoo
|
|
|
|
|
|
|
|
# ===============================
|
|
|
|
# Register single-sentence GPT
|
|
|
|
# ===============================
|
2023-03-17 03:00:15 +00:00
|
|
|
# Must stay at 1: GPT cannot handle batch sizes > 1 if no padding token is defined.
BATCH_SIZE: int = 1
|
2023-03-15 03:26:10 +00:00
|
|
|
# Number of tokens in every sample produced by the data generator.
SEQ_LENGTH: int = 16
|
|
|
|
|
|
|
|
|
|
|
|
def data_gen(batch_size=None, seq_length=None):
    """Build a dummy GPT-2 input batch for the model-zoo tests.

    Args:
        batch_size: number of samples; ``None`` (the default) uses ``BATCH_SIZE``.
        seq_length: tokens per sample; ``None`` (the default) uses ``SEQ_LENGTH``.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``,
        each an int64 tensor of shape ``(batch_size, seq_length)``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if seq_length is None:
        seq_length = SEQ_LENGTH
    shape = (batch_size, seq_length)

    # All-zero ids / token types serve as placeholder tokens.
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    # 1 means "attend to this position". An all-zeros mask would flag every
    # token as padding and collapse the attention weights to a uniform
    # distribution, hiding real attention behaviour from the tests.
    attention_mask = torch.ones(shape, dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
|
|
def output_transform_fn(x):
    """Return the model output unchanged (identity transform)."""
    return x
|
|
|
|
|
|
|
|
# Small GPT-2 (2 layers, 4 heads, 64 positions) shared by every registered variant.
# The keyword is ``n_positions``: GPT2Config stores unknown kwargs such as the
# misspelled ``n_position`` as plain attributes, which would silently leave the
# position-embedding table at the library default size.
config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
|
|
|
|
|
|
|
|
def _gpt_model_fn(class_name):
    """Return a zero-argument factory building ``transformers.<class_name>`` from ``config``.

    The class is looked up only when the factory is called, so the attribute
    access happens at model-construction time rather than at import time.
    """
    return lambda: getattr(transformers, class_name)(config)


# register the following models (zoo name, transformers class name).
# All variants share the same config, input generator and output transform.
_GPT_VARIANTS = (
    ('transformers_gpt', 'GPT2Model'),
    ('transformers_gpt_lm', 'GPT2LMHeadModel'),
    ('transformers_gpt_double_heads', 'GPT2DoubleHeadsModel'),
    ('transformers_gpt_for_token_classification', 'GPT2ForTokenClassification'),
    ('transformers_gpt_for_sequence_classification', 'GPT2ForSequenceClassification'),
)

for _zoo_name, _class_name in _GPT_VARIANTS:
    model_zoo.register(name=_zoo_name,
                       model_fn=_gpt_model_fn(_class_name),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       model_attribute=ModelAttribute(has_control_flow=True))
|