diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py index 4e0928021..aa517e7bc 100644 --- a/tests/test_engine/test_engine.py +++ b/tests/test_engine/test_engine.py @@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None def run_train(): - assert non_distributed_component_funcs.get_callable('bert') for get_components_func in non_distributed_component_funcs: model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func() @@ -27,12 +26,15 @@ def run_train(): try: engine.train() - for img, label in train_dataloader: + for data, label in train_dataloader: engine.zero_grad() - img = img.cuda() + data = data.cuda() label = label.cuda() - output = engine(img) - loss = engine.criterion(output, label) + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + else: + loss = engine(data, label) engine.backward(loss) engine.step() break @@ -72,9 +74,9 @@ def run_engine(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_with_no_amp() - # run_with_torch_amp() - # run_with_apex_amp() - # run_with_naive_amp() + run_with_torch_amp() + run_with_apex_amp() + run_with_naive_amp() @pytest.mark.dist diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index 7f1af20cf..97ae6634f 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -76,6 +76,7 @@ def run_dist(rank, world_size, port): check_grads(model, zero_model, loose=True) +@pytest.mark.skip(reason="Under development") @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2, 4]) def test_shard_model_v2(world_size):