fixed ddp bug on torch 1.8 (#194)

Frank Lee 2022-01-28 15:14:04 +08:00 committed by GitHub
parent 569357fea0
commit 765db512b5
1 changed file with 2 additions and 2 deletions


@@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                         "added even though not specified in the configuration",
                         ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
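
For reference, a minimal standalone sketch of the pattern this fix adopts: passing device_ids=[torch.cuda.current_device()] so each process explicitly pins its DDP replica to its local GPU instead of relying on device inference. The toy Linear model, NCCL backend, and torchrun-style launch below are assumptions for illustration only, not part of this commit or ColossalAI's API.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun exports RANK/WORLD_SIZE/LOCAL_RANK; NCCL is the usual GPU backend
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(16, 16).cuda()
    # Same pattern as the fix: bind the replica to this process's GPU explicitly
    # via device_ids rather than letting DDP infer it from the module's parameters.
    model = DDP(model, device_ids=[torch.cuda.current_device()])

    x = torch.randn(4, 16, device="cuda")
    model(x).sum().backward()  # gradients are all-reduced across ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launch with something like torchrun --nproc_per_node=<num_gpus> on a multi-GPU host so one process owns each device.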