adapting bert unitest interface

pull/394/head
jiaruifang 2022-03-09 11:26:10 +08:00 committed by Frank Lee
parent 7977422aeb
commit 4d94cd513e
2 changed files with 18 additions and 6 deletions

View File

@ -48,9 +48,21 @@ def get_training_components():
num_hidden_layers=num_layer,
)
print('building BertForSequenceClassification model')
model = BertForSequenceClassification(config)
# adapting huggingface BertForSequenceClassification for single unitest calling interface
class ModelAaptor(BertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable()
return model
trainloader = get_bert_data_loader(batch_size=2,

View File

@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss.backward()
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
output = model(input_ids=data, labels=label)
loss = output[0]
loss = model(data, label)
if isinstance(model, ShardedModelV2):
model.backward(loss)
else:
@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
if model_name == 'bert':
data, label = data.cuda(), label.cuda()
run_bert_fwd_bwd(model, data, label, False)
run_bert_fwd_bwd(zero_model, data, label, False)
run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False)
else:
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)