mirror of https://github.com/hpcaitech/ColossalAI
[bug] fixed DDP compatibility with torch 1.8 (#739)
parent a4e91bc87f
commit f4f42d4c3c
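The change is the same in all three tests: the model that was previously passed to DDP bare is now wrapped with device_ids pinned to the process's current CUDA device. A minimal sketch of the resulting pattern, assuming the default process group is already initialized; the helper name and standalone function below are illustrative, not code from the patched tests:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # Assumes dist.init_process_group(...) has been called and the current
    # CUDA device has been set for this rank.
    model = model.cuda()
    if dist.get_world_size() > 1:
        # Pass device_ids explicitly so the replica is pinned to this
        # process's GPU; this is the argument the commit adds in each test.
        model = DDP(model, device_ids=[torch.cuda.current_device()])
    return model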
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     col_model_deepcopy(zero_model, model)
     model = model.cuda()
 
-    model = DDP(model)
+    model = DDP(model, device_ids=[torch.cuda.current_device()])
 
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
@@ -86,7 +86,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     if dist.get_world_size() > 1:
-        apex_model = DDP(apex_model)
+        apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()])
 
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
@@ -50,7 +50,7 @@ def run_dist(rank, world_size, port, parallel_config):
     torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
 
     if dist.get_world_size() > 1:
-        torch_model = DDP(torch_model)
+        torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])
 
     i = 0
     for data, label in train_dataloader: