From f4f42d4c3c30f980defcdd2634d282c62763c9a8 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Wed, 13 Apr 2022 00:08:46 +0800
Subject: [PATCH] [bug] fixed DDP compatibility with torch 1.8 (#739)

---
 tests/test_zero/test_shard_model_v2.py   | 2 +-
 tests/test_zero/test_sharded_optim_v2.py | 2 +-
 tests/test_zero/test_zero_engine.py      | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py
index bf84fd29a..2d230f85f 100644
--- a/tests/test_zero/test_shard_model_v2.py
+++ b/tests/test_zero/test_shard_model_v2.py
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         col_model_deepcopy(zero_model, model)
         model = model.cuda()
 
-        model = DDP(model)
+        model = DDP(model, device_ids=[torch.cuda.current_device()])
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 5:
diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py
index 0c8f8ea66..34287969f 100644
--- a/tests/test_zero/test_sharded_optim_v2.py
+++ b/tests/test_zero/test_sharded_optim_v2.py
@@ -86,7 +86,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     if dist.get_world_size() > 1:
-        apex_model = DDP(apex_model)
+        apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()])
 
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py
index 50153427c..82e0f69ca 100644
--- a/tests/test_zero/test_zero_engine.py
+++ b/tests/test_zero/test_zero_engine.py
@@ -50,7 +50,7 @@ def run_dist(rank, world_size, port, parallel_config):
     torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
 
     if dist.get_world_size() > 1:
-        torch_model = DDP(torch_model)
+        torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])
 
     i = 0
     for data, label in train_dataloader:
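
Note: the patch applies the standard PyTorch pattern of pinning each DistributedDataParallel wrapper to its process's GPU via device_ids, which the tests previously omitted and which, per the commit subject, is needed for compatibility with torch 1.8. The sketch below is illustrative only and is not part of the patch; the toy nn.Linear model and the torchrun-style environment variables are assumptions.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # Assumes launch via torchrun, which sets RANK, WORLD_SIZE and LOCAL_RANK.
        dist.init_process_group(backend='nccl')
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)

        model = nn.Linear(16, 16).cuda()
        # Pin the DDP wrapper to this process's GPU, mirroring the change made
        # in the three test files above.
        model = DDP(model, device_ids=[torch.cuda.current_device()])

        out = model(torch.randn(4, 16, device='cuda'))
        print(out.shape)

        dist.destroy_process_group()

    if __name__ == '__main__':
        main()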