
[bug] fixed DDP compatibility with torch 1.8 (#739)

Branch: pull/742/head
Commit f4f42d4c3c by Frank Lee, 3 years ago (committed by GitHub)
  1. tests/test_zero/test_shard_model_v2.py (2 changed lines)
  2. tests/test_zero/test_sharded_optim_v2.py (2 changed lines)
  3. tests/test_zero/test_zero_engine.py (2 changed lines)

tests/test_zero/test_shard_model_v2.py (2 changed lines)

@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     col_model_deepcopy(zero_model, model)
     model = model.cuda()
-    model = DDP(model)
+    model = DDP(model, device_ids=[torch.cuda.current_device()])
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:

tests/test_zero/test_sharded_optim_v2.py (2 changed lines)

@@ -86,7 +86,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     if dist.get_world_size() > 1:
-        apex_model = DDP(apex_model)
+        apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()])
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:

tests/test_zero/test_zero_engine.py (2 changed lines)

@@ -50,7 +50,7 @@ def run_dist(rank, world_size, port, parallel_config):
     torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
     if dist.get_world_size() > 1:
-        torch_model = DDP(torch_model)
+        torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])
     i = 0
     for data, label in train_dataloader:
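
All three hunks apply the same one-line fix: when wrapping a CUDA model in DistributedDataParallel, pass device_ids explicitly so the replica is pinned to the current rank's GPU. A minimal sketch of that pattern follows; the wrap_for_ddp helper name and the process-group setup are illustrative assumptions, not part of this commit.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # Assumes dist.init_process_group(...) has already run and this rank
    # has selected its GPU, e.g. via torch.cuda.set_device(rank).
    model = model.cuda()
    if dist.get_world_size() > 1:
        # Passing device_ids pins the replica to this rank's device; per
        # this commit, omitting it breaks DDP under torch 1.8.
        model = DDP(model, device_ids=[torch.cuda.current_device()])
    return model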
