# ColossalAI/tests/kit/model_zoo/transformers/gpt.py
# (58 lines, 2.7 KiB)

import torch
import transformers
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence GPT
# ===============================
# Batch size used by data_gen(). It can only be 1, as GPT cannot handle batch
# sizes > 1 if no padding token is defined.
BATCH_SIZE = 1
# Size of the second dimension (sequence length) of every tensor built by the
# data generators in this file.
SEQ_LENGTH = 16
def data_gen():
    """Build a dummy single-sentence input batch for the GPT-2 models.

    Returns:
        dict: ``input_ids``, ``token_type_ids`` and ``attention_mask``, each an
        int64 tensor of shape ``(BATCH_SIZE, SEQ_LENGTH)``. The ids are all
        zeros; the attention mask is all ones.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    # The input contains no padding, so every position must be attended to.
    # In the Hugging Face convention 1 means "attend" and 0 means "masked";
    # an all-zero mask would mark the whole sequence as masked.
    attention_mask = torch.ones(shape, dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def seq_classification_data_gen():
    """Build a dummy input batch for GPT-2 sequence classification.

    The batch dimension is hard-coded to 1: batch sizes should be 1 if no
    padding token is defined.

    Returns:
        dict: ``input_ids``, ``token_type_ids`` and ``attention_mask``, each an
        int64 tensor of shape ``(1, SEQ_LENGTH)``. The ids are all zeros; the
        attention mask is all ones.
    """
    shape = (1, SEQ_LENGTH)
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    # The input contains no padding, so every position must be attended to.
    # In the Hugging Face convention 1 means "attend" and 0 means "masked";
    # an all-zero mask would mark the whole sequence as masked.
    attention_mask = torch.ones(shape, dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def output_transform_fn(x):
    """Return the model output ``x`` unchanged (identity transform)."""
    return x
# Small GPT-2 configuration (64 positions, 2 layers, 4 heads) shared by every
# model registered in this file.
# The keyword must be ``n_positions`` (plural): GPT2Config accepts unknown
# keyword arguments without raising, so a misspelt ``n_position`` would leave
# the maximum number of positions at the library default.
config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
# Register every GPT-2 variant with the model zoo.
def _register_gpt(name, model_cls_name, data_gen_fn):
    """Register one GPT-2 variant under ``name`` in the model zoo.

    ``model_cls_name`` is the attribute name of the model class on the
    ``transformers`` package. It is resolved only when ``model_fn`` is called,
    and the model is built from the module-level ``config``.
    """
    model_zoo.register(name=name,
                       model_fn=lambda: getattr(transformers, model_cls_name)(config),
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=output_transform_fn,
                       model_attribute=ModelAttribute(has_control_flow=True))


_register_gpt('transformers_gpt', 'GPT2Model', data_gen)
_register_gpt('transformers_gpt_lm', 'GPT2LMHeadModel', data_gen)
_register_gpt('transformers_gpt_double_heads', 'GPT2DoubleHeadsModel', data_gen)
_register_gpt('transformers_gpt_for_token_classification', 'GPT2ForTokenClassification', data_gen)
# Sequence classification uses its own generator with the batch size fixed at 1.
_register_gpt('transformers_gpt_for_sequence_classification', 'GPT2ForSequenceClassification',
              seq_classification_data_gen)