You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/tests/components_to_test/albert.py

60 lines
2.4 KiB

import torch
import transformers
from packaging import version
from transformers import AlbertConfig, AlbertForSequenceClassification
from .bert import get_bert_data_loader
from .registry import non_distributed_component_funcs
@non_distributed_component_funcs.register(name='albert')
def get_training_components():
hidden_dim = 8
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint: bool = False):
config = AlbertConfig(vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.,
attention_probs_dropout_prob=0.)
print('building AlbertForSequenceClassification model')
# adapting huggingface BertForSequenceClassification for single unitest calling interface
class ModelAaptor(AlbertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAaptor(config)
# if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
# model.gradient_checkpointing_enable()
return model
is_distrbuted = torch.distributed.is_initialized()
trainloader = get_bert_data_loader(n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distrbuted=is_distrbuted)
testloader = get_bert_data_loader(n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distrbuted=is_distrbuted)
criterion = None
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion