fix bert unit test

pull/394/head
ver217 2022-03-09 13:38:20 +08:00 committed by Frank Lee
parent 5663616921
commit f5f0ad266e
2 changed files with 13 additions and 15 deletions

@@ -1,9 +1,9 @@
 import torch
 import transformers
-from transformers import BertConfig, BertForSequenceClassification
 from packaging import version
 from torch.utils.data import SequentialSampler
+from transformers import BertConfig, BertForSequenceClassification

 from .registry import non_distributed_component_funcs
@@ -39,14 +39,14 @@ def get_training_components():
     num_layer = 2

     def bert_model_builder(checkpoint):
-        config = BertConfig(
-            gradient_checkpointing=checkpoint,
-            hidden_size=hidden_dim,
-            intermediate_size=hidden_dim * 4,
-            num_attention_heads=num_head,
-            max_position_embeddings=sequence_length,
-            num_hidden_layers=num_layer,
-        )
+        config = BertConfig(gradient_checkpointing=checkpoint,
+                            hidden_size=hidden_dim,
+                            intermediate_size=hidden_dim * 4,
+                            num_attention_heads=num_head,
+                            max_position_embeddings=sequence_length,
+                            num_hidden_layers=num_layer,
+                            hidden_dropout_prob=0.,
+                            attention_probs_dropout_prob=0.)
         print('building BertForSequenceClassification model')
         # adapting huggingface BertForSequenceClassification for single unitest calling interface

@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -45,8 +46,7 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):

 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_models = ['bert']
-    # repeated_computed_layers resnet18
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = TensorShardStrategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
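
For reference, non_distributed_component_funcs maps a model name to a function returning that model's training components; with the commented-out names moved back into test_models, all three registered components run again. Its implementation is not shown in this diff; a sketch of the assumed registry pattern (hypothetical; the real one lives in tests/components_to_test/registry.py):

    class ComponentRegistry:
        def __init__(self):
            self._funcs = {}

        def register(self, name):
            # Decorator that stores a component-builder under a model name.
            def wrapper(func):
                self._funcs[name] = func
                return func
            return wrapper

        def get_callable(self, name):
            return self._funcs[name]

    non_distributed_component_funcs = ComponentRegistry()

    @non_distributed_component_funcs.register('bert')
    def get_training_components():
        # would return the model builder, dataloaders, criterion, etc.
        ...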
@@ -65,8 +65,7 @@ def run_dist(rank, world_size, port):
             run_fwd_bwd_no_criterion(model, data, label, False)
             run_fwd_bwd_no_criterion(zero_model, data, label, False)
         else:
-            # FIXME() data can be interger!
-            data, label = data.half().cuda(), label.cuda()
+            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
             run_fwd_bwd(model, data, label, criterion, False)
             run_fwd_bwd(zero_model, data, label, criterion, False)
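
This change resolves the old FIXME: BERT's inputs are integer token ids, and a blanket data.half() would downcast them along with any float tensors. The new helper presumably casts only floating-point tensors; a sketch of that assumed behavior (hypothetical, not the actual ColossalAI implementation):

    import torch

    def cast_tensor_to_fp16_sketch(tensor: torch.Tensor) -> torch.Tensor:
        # Integer tensors such as input_ids or attention masks must keep
        # their dtype; only floating-point tensors go to half precision.
        if torch.is_floating_point(tensor):
            return tensor.half()
        return tensor

    ids = torch.randint(0, 100, (4, 8))   # token ids stay int64
    assert cast_tensor_to_fp16_sketch(ids).dtype == torch.int64
    feats = torch.randn(4, 8)             # float features become fp16
    assert cast_tensor_to_fp16_sketch(feats).dtype == torch.float16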
@@ -76,7 +75,6 @@ def run_dist(rank, world_size, port):
         check_grads(model, zero_model, loose=True)


-@pytest.mark.skip(reason="Under development")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2, 4])
 def test_shard_model_v2(world_size):
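
With the skip marker removed, the parametrized test runs again at world sizes 1, 2, and 4. Its body is not shown in this diff; based on the imports visible above (free_port, colossalai.launch), it most likely spawns run_dist across processes, roughly like this sketch (a guess, not part of the diff):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp
    from colossalai.utils import free_port

    @pytest.mark.dist
    @pytest.mark.parametrize("world_size", [1, 2, 4])
    def test_shard_model_v2(world_size):
        # Bind everything except the rank, which mp.spawn supplies to each
        # worker process; run_dist is the function from the file above.
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)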