# ColossalAI/tests/components_to_test/bert.py
# (file-viewer metadata: 90 lines, 2.9 KiB, Python)

import torch
import transformers
from packaging import version
from torch.utils.data import SequentialSampler
# (blame-view timestamp: 2022-03-09 05:38:20 +00:00)
from transformers import BertConfig, BertForSequenceClassification
from .registry import non_distributed_component_funcs
def get_bert_data_loader(
    n_class: int,
    batch_size: int,
    total_samples: int,
    sequence_length: int,
    device: torch.device = torch.device("cpu:0"),
    is_distributed: bool = False,
    num_labels: int = 2,
) -> torch.utils.data.DataLoader:
    """Build a DataLoader over randomly generated token ids and labels.

    Args:
        n_class: Exclusive upper bound of the random token ids; every id is
            drawn from ``[0, n_class)``.
        batch_size: Number of samples per batch.
        total_samples: Number of samples in the synthetic dataset.
        sequence_length: Number of token ids per sample.
        device: Device on which the random tensors are created.
        is_distributed: When true, samples are drawn through a
            ``DistributedSampler``; otherwise a ``SequentialSampler`` is used.
        num_labels: Exclusive upper bound of the random labels; every label
            is drawn from ``[0, num_labels)``. Defaults to 2.

    Returns:
        A DataLoader yielding ``(token_ids, labels)`` batches, where
        ``token_ids`` has shape ``(batch, sequence_length)`` and ``labels``
        has shape ``(batch,)``; both are ``torch.long``.
    """
    train_data = torch.randint(
        low=0,
        high=n_class,
        size=(total_samples, sequence_length),
        device=device,
        dtype=torch.long,
    )
    train_label = torch.randint(
        low=0,
        high=num_labels,
        size=(total_samples,),
        device=device,
        dtype=torch.long,
    )
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)

    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = SequentialSampler(train_dataset)

    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
@non_distributed_component_funcs.register(name="bert")
def get_training_components():
    """Assemble the tiny-BERT components registered under the name "bert".

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``: ``model_builder`` is a callable taking ``checkpoint``
        and returning the model, both loaders yield random
        ``(input_ids, labels)`` batches, ``optimizer_class`` is
        ``torch.optim.Adam``, and ``criterion`` is ``None`` because the
        model's ``forward`` already returns the loss.
    """
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 32

    def bert_model_builder(checkpoint: bool = False):
        """Build the adapted BERT classifier.

        Args:
            checkpoint: Whether to enable gradient checkpointing.

        Returns:
            A ``BertForSequenceClassification`` subclass instance whose
            ``forward(input_ids, labels)`` returns only the loss.
        """
        # From transformers 4.11.0 on, checkpointing is switched on through
        # `model.gradient_checkpointing_enable()`, and passing the flag to the
        # config is deprecated. Older releases only understand the config flag.
        has_enable_api = version.parse(transformers.__version__) >= version.parse("4.11.0")

        config_kwargs = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_dim,
            intermediate_size=hidden_dim * 4,
            num_attention_heads=num_head,
            max_position_embeddings=sequence_length,
            num_hidden_layers=num_layer,
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )
        if not has_enable_api:
            config_kwargs["gradient_checkpointing"] = checkpoint
        config = BertConfig(**config_kwargs)
        print("building BertForSequenceClassification model")

        # Adapt huggingface BertForSequenceClassification to the single
        # (data, label) -> loss calling interface used by the unit tests.
        class ModelAdaptor(BertForSequenceClassification):

            def forward(self, input_ids, labels):
                """
                inputs: data, label
                outputs: loss
                """
                return super().forward(input_ids=input_ids, labels=labels)[0]

        model = ModelAdaptor(config)
        if checkpoint and has_enable_api:
            model.gradient_checkpointing_enable()
        return model

    # Guard with is_available() so builds without distributed support do not
    # fail when querying the process-group state.
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

    loader_kwargs = dict(
        n_class=vocab_size,
        batch_size=2,
        total_samples=10000,
        sequence_length=sequence_length,
        is_distributed=is_distributed,
    )
    # Two separate calls: train and test loaders each get their own random data.
    trainloader = get_bert_data_loader(**loader_kwargs)
    testloader = get_bert_data_loader(**loader_kwargs)

    criterion = None
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion