mirror of https://github.com/hpcaitech/ColossalAI
fix bert unit test
parent
5663616921
commit
f5f0ad266e
|
@ -1,9 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import BertConfig, BertForSequenceClassification
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from torch.utils.data import SequentialSampler
|
from torch.utils.data import SequentialSampler
|
||||||
|
from transformers import BertConfig, BertForSequenceClassification
|
||||||
|
|
||||||
from .registry import non_distributed_component_funcs
|
from .registry import non_distributed_component_funcs
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,14 +39,14 @@ def get_training_components():
|
||||||
num_layer = 2
|
num_layer = 2
|
||||||
|
|
||||||
def bert_model_builder(checkpoint):
|
def bert_model_builder(checkpoint):
|
||||||
config = BertConfig(
|
config = BertConfig(gradient_checkpointing=checkpoint,
|
||||||
gradient_checkpointing=checkpoint,
|
hidden_size=hidden_dim,
|
||||||
hidden_size=hidden_dim,
|
intermediate_size=hidden_dim * 4,
|
||||||
intermediate_size=hidden_dim * 4,
|
num_attention_heads=num_head,
|
||||||
num_attention_heads=num_head,
|
max_position_embeddings=sequence_length,
|
||||||
max_position_embeddings=sequence_length,
|
num_hidden_layers=num_layer,
|
||||||
num_hidden_layers=num_layer,
|
hidden_dropout_prob=0.,
|
||||||
)
|
attention_probs_dropout_prob=0.)
|
||||||
print('building BertForSequenceClassification model')
|
print('building BertForSequenceClassification model')
|
||||||
|
|
||||||
# adapting huggingface BertForSequenceClassification for single unitest calling interface
|
# adapting huggingface BertForSequenceClassification for single unitest calling interface
|
||||||
|
|
|
@ -13,6 +13,7 @@ from colossalai.utils import free_port
|
||||||
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
from colossalai.zero.shard_utils.tensor_shard_strategy import \
|
||||||
TensorShardStrategy
|
TensorShardStrategy
|
||||||
from colossalai.zero.sharded_model import ShardedModelV2
|
from colossalai.zero.sharded_model import ShardedModelV2
|
||||||
|
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
@ -45,8 +46,7 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
test_models = ['bert']
|
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||||
# repeated_computed_layers resnet18
|
|
||||||
shard_strategy = TensorShardStrategy()
|
shard_strategy = TensorShardStrategy()
|
||||||
for model_name in test_models:
|
for model_name in test_models:
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
|
@ -65,8 +65,7 @@ def run_dist(rank, world_size, port):
|
||||||
run_fwd_bwd_no_criterion(model, data, label, False)
|
run_fwd_bwd_no_criterion(model, data, label, False)
|
||||||
run_fwd_bwd_no_criterion(zero_model, data, label, False)
|
run_fwd_bwd_no_criterion(zero_model, data, label, False)
|
||||||
else:
|
else:
|
||||||
# FIXME() data can be interger!
|
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
|
||||||
data, label = data.half().cuda(), label.cuda()
|
|
||||||
run_fwd_bwd(model, data, label, criterion, False)
|
run_fwd_bwd(model, data, label, criterion, False)
|
||||||
run_fwd_bwd(zero_model, data, label, criterion, False)
|
run_fwd_bwd(zero_model, data, label, criterion, False)
|
||||||
|
|
||||||
|
@ -76,7 +75,6 @@ def run_dist(rank, world_size, port):
|
||||||
check_grads(model, zero_model, loose=True)
|
check_grads(model, zero_model, loose=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Under development")
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
||||||
def test_shard_model_v2(world_size):
|
def test_shard_model_v2(world_size):
|
||||||
|
|
Loading…
Reference in New Issue