import torch
from transformers import AlbertConfig, AlbertForSequenceClassification

from .bert import get_bert_data_loader
from .registry import non_distributed_component_funcs


@non_distributed_component_funcs.register(name="albert")
def get_training_components():
    """Assemble the pieces of a tiny ALBERT sequence-classification test case.

    The function is registered under the name ``"albert"`` in
    ``non_distributed_component_funcs``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where

        * ``model_builder(checkpoint=False)`` builds a fresh model,
        * ``trainloader`` / ``testloader`` come from ``get_bert_data_loader``
          and are created with identical arguments,
        * ``optimizer_class`` is ``torch.optim.Adam`` (the class, not an
          instance),
        * ``criterion`` is ``None``; the model's ``forward`` already returns
          the loss.
    """
    # Deliberately tiny model dimensions.
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 32

    def bert_model_builder(checkpoint: bool = False):
        """Build the small ALBERT classifier.

        Args:
            checkpoint: forwarded to ``AlbertConfig`` as
                ``gradient_checkpointing``.

        Returns:
            A ``ModelAdaptor`` instance whose ``forward(input_ids, labels)``
            returns the first element of the HuggingFace output.
        """
        config = AlbertConfig(
            vocab_size=vocab_size,
            gradient_checkpointing=checkpoint,
            hidden_size=hidden_dim,
            intermediate_size=hidden_dim * 4,
            num_attention_heads=num_head,
            max_position_embeddings=sequence_length,
            num_hidden_layers=num_layer,
            # Dropout is switched off in both places.
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )
        print("building AlbertForSequenceClassification model")

        # Adapt HuggingFace AlbertForSequenceClassification to the
        # (data, label) -> loss calling interface used by the unit tests.
        class ModelAdaptor(AlbertForSequenceClassification):
            def forward(self, input_ids, labels):
                """Run the wrapped model and return only the loss.

                inputs: data, label
                outputs: loss
                """
                outputs = super().forward(input_ids=input_ids, labels=labels)
                return outputs[0]

        model = ModelAdaptor(config)
        # NOTE(review): an alternative path calling
        # ``model.gradient_checkpointing_enable()`` for transformers >= 4.11.0
        # was left disabled here; checkpointing is requested only through the
        # config flag above. Confirm this is still honoured by the installed
        # transformers version.
        return model

    # torch.distributed exposes no other APIs when the distributed package is
    # not available, so check availability before asking for the init state.
    is_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )

    # Both loaders are built from the same settings.
    loader_kwargs = dict(
        n_class=vocab_size,
        batch_size=2,
        total_samples=10000,
        sequence_length=sequence_length,
        is_distributed=is_distributed,
    )
    trainloader = get_bert_data_loader(**loader_kwargs)
    testloader = get_bert_data_loader(**loader_kwargs)

    criterion = None
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion