From d271f2596b94a12fa6b0a1d67e2269c463ea159c Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 9 Mar 2022 12:03:49 +0800 Subject: [PATCH] polish engine unittest --- tests/test_engine/test_engine.py | 18 ++++++++++-------- .../test_shard_model_v2.py | 1 + 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py index 4e0928021..aa517e7bc 100644 --- a/tests/test_engine/test_engine.py +++ b/tests/test_engine/test_engine.py @@ -15,7 +15,6 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None def run_train(): - assert non_distributed_component_funcs.get_callable('bert') for get_components_func in non_distributed_component_funcs: model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func() @@ -27,12 +26,15 @@ def run_train(): try: engine.train() - for img, label in train_dataloader: + for data, label in train_dataloader: engine.zero_grad() - img = img.cuda() + data = data.cuda() label = label.cuda() - output = engine(img) - loss = engine.criterion(output, label) + if criterion: + output = engine(data) + loss = engine.criterion(output, label) + else: + loss = engine(data, label) engine.backward(loss) engine.step() break @@ -72,9 +74,9 @@ def run_engine(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_with_no_amp() - # run_with_torch_amp() - # run_with_apex_amp() - # run_with_naive_amp() + run_with_torch_amp() + run_with_apex_amp() + run_with_naive_amp() @pytest.mark.dist diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index 7f1af20cf..97ae6634f 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -76,6 +76,7 @@ def run_dist(rank, world_size, port): check_grads(model, zero_model, 
loose=True) +@pytest.mark.skip(reason="Under development") @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2, 4]) def test_shard_model_v2(world_size):