From 4d94cd513e5c2b264ba27c775f3e24cf72e0b309 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 9 Mar 2022 11:26:10 +0800 Subject: [PATCH] adapting bert unitest interface --- tests/components_to_test/bert.py | 14 +++++++++++++- .../test_zero_data_parallel/test_shard_model_v2.py | 10 +++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index ea48eeac6..ac9e163f7 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -48,9 +48,21 @@ def get_training_components(): num_hidden_layers=num_layer, ) print('building BertForSequenceClassification model') - model = BertForSequenceClassification(config) + + # adapting huggingface BertForSequenceClassification for single unitest calling interface + class ModelAaptor(BertForSequenceClassification): + + def forward(self, input_ids, labels): + """ + inputs: data, label + outputs: loss + """ + return super().forward(input_ids=input_ids, labels=labels)[0] + + model = ModelAaptor(config) if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): model.gradient_checkpointing_enable() + return model trainloader = get_bert_data_loader(batch_size=2, diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index f7f191171..0c5686d16 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): loss.backward() -def run_bert_fwd_bwd(model, data, label, enable_autocast=False): +# with no criterion +def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False): model.train() with torch.cuda.amp.autocast(enabled=enable_autocast): - output = model(input_ids=data, labels=label) - loss = output[0] + loss = model(data, label) if isinstance(model, ShardedModelV2): model.backward(loss) else: @@ -60,8 +60,8 @@ def run_dist(rank, world_size, port): if model_name == 'bert': data, label = data.cuda(), label.cuda() - run_bert_fwd_bwd(model, data, label, False) - run_bert_fwd_bwd(zero_model, data, label, False) + run_fwd_bwd_no_criterion(model, data, label, False) + run_fwd_bwd_no_criterion(zero_model, data, label, False) else: data, label = data.half().cuda(), label.cuda() run_fwd_bwd(model, data, label, criterion, False)