mirror of https://github.com/hpcaitech/ColossalAI
fixed ddp bug on torch 1.8 (#194)
parent 569357fea0
commit 765db512b5
@@ -348,12 +348,12 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     "added even though not specified in the configuration",
                     ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info(
                 'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
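For context, below is a minimal standalone sketch (not ColossalAI code) of the pattern the fix applies: wrapping a model in torch.nn.parallel.DistributedDataParallel with device_ids pinned to the current CUDA device, which is what the patched lines add for torch 1.8. The function name wrap_model_for_ddp, the default process group, and the torchrun launch are illustrative assumptions; ColossalAI passes the group returned by gpc.get_group(...) instead.

# Minimal sketch, assuming one GPU per process and a torchrun-style launch.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model_for_ddp(model: torch.nn.Module, process_group=None) -> DDP:
    """Move the model to this rank's GPU and wrap it in DDP.

    Passing device_ids=[torch.cuda.current_device()] tells DDP explicitly
    which single device the module lives on, mirroring the change in this
    commit; process_group stands in for gpc.get_group(ParallelMode.DATA)
    or gpc.get_group(ParallelMode.SEQUENCE_DP) in ColossalAI.
    """
    device = torch.cuda.current_device()
    model = model.cuda(device)
    return DDP(model,
               process_group=process_group,
               device_ids=[device])


if __name__ == "__main__":
    # Assumes launch via `torchrun --nproc_per_node=<num_gpus> this_script.py`,
    # which sets RANK / WORLD_SIZE / LOCAL_RANK in the environment.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    ddp_model = wrap_model_for_ddp(torch.nn.Linear(16, 16))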