2022-12-05 06:09:34 +00:00
|
|
|
import torch
|
|
|
|
from transformers import AlbertConfig, AlbertForSequenceClassification
|
|
|
|
|
|
|
|
from .bert import get_bert_data_loader
|
|
|
|
from .registry import non_distributed_component_funcs
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
@non_distributed_component_funcs.register(name="albert")
def get_training_components():
    """Assemble the pieces of the "albert" test component.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where ``model_builder`` is a callable accepting a
        ``checkpoint`` flag, both loaders come from ``get_bert_data_loader``,
        ``optimizer_class`` is ``torch.optim.Adam`` and ``criterion`` is
        ``None``.
    """
    # Deliberately tiny hyper-parameters for the model and synthetic data.
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 32

    def bert_model_builder(checkpoint: bool = False):
        """Build a small ALBERT sequence-classification model.

        Args:
            checkpoint: Forwarded to ``AlbertConfig`` as
                ``gradient_checkpointing``.

        Returns:
            A ``ModelAdaptor`` instance built from the config.
        """
        config_kwargs = {
            "vocab_size": vocab_size,
            "gradient_checkpointing": checkpoint,
            "hidden_size": hidden_dim,
            "intermediate_size": hidden_dim * 4,
            "num_attention_heads": num_head,
            "max_position_embeddings": sequence_length,
            "num_hidden_layers": num_layer,
            # Dropout is set to zero for both hidden states and attention.
            "hidden_dropout_prob": 0.0,
            "attention_probs_dropout_prob": 0.0,
        }
        config = AlbertConfig(**config_kwargs)
        print("building AlbertForSequenceClassification model")

        # Adapt Hugging Face's AlbertForSequenceClassification to the
        # two-argument (data, label) calling interface used by the unit tests.
        class ModelAdaptor(AlbertForSequenceClassification):
            def forward(self, input_ids, labels):
                """Run the parent forward pass and return its first output.

                inputs: data, label
                outputs: loss
                """
                outputs = super().forward(input_ids=input_ids, labels=labels)
                return outputs[0]

        return ModelAdaptor(config)

    # Both loaders are created with exactly the same arguments.
    loader_kwargs = {
        "n_class": vocab_size,
        "batch_size": 2,
        "total_samples": 10000,
        "sequence_length": sequence_length,
        "is_distributed": torch.distributed.is_initialized(),
    }
    trainloader = get_bert_data_loader(**loader_kwargs)
    testloader = get_bert_data_loader(**loader_kwargs)

    # The criterion slot is None; the adaptor's forward already returns the
    # first element of the parent model's output.
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, None
|