mirror of https://github.com/hpcaitech/ColossalAI
skip bert in test engine
parent
d41a9f12c6
commit
2b8cddd40e
|
@ -4,9 +4,9 @@ import colossalai
|
|||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.context import Config
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.context import Config
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
|
||||
|
@ -15,7 +15,10 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
|
|||
|
||||
|
||||
def run_train():
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']
|
||||
# FIXME: test bert
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
|
||||
|
||||
model = model_builder(checkpoint=False)
|
||||
|
|
Loading…
Reference in New Issue