mirror of https://github.com/hpcaitech/ColossalAI
[ddp] supported customized torch ddp configuration (#1123)
parent: fcf55777dd
commit: 91a5999825
@@ -343,6 +343,9 @@ def initialize(model: nn.Module,
                                                      mode=amp_mode,
                                                      amp_config=cfg_)
 
+    # get torch ddp config
+    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
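The first hunk reads an optional `torch_ddp` block from the Colossal-AI configuration via `gpc.config.get('torch_ddp', dict())`; whatever it contains is later unpacked into `torch.nn.parallel.DistributedDataParallel`. A minimal sketch of a config file that uses the new block (the specific keys, `find_unused_parameters` and `broadcast_buffers`, are ordinary DDP keyword arguments chosen for illustration, not values taken from this commit):

# config.py -- the Colossal-AI config file passed at launch.
# Every key in `torch_ddp` is forwarded verbatim to
# torch.nn.parallel.DistributedDataParallel, so only valid DDP
# keyword arguments belong here.
torch_ddp = dict(
    find_unused_parameters=True,   # example value, not from the commit
    broadcast_buffers=False,       # example value, not from the commit
)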
@@ -368,12 +371,16 @@ def initialize(model: nn.Module,
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
-                    device_ids=[torch.cuda.current_device()])
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                         ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.DATA),
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
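The second hunk simply unpacks that dictionary into each `DDP(...)` call. The pattern in isolation, as a self-contained sketch (hypothetical option values; it assumes a default process group has already been initialised, e.g. by `colossalai.launch` or `torch.distributed.init_process_group`):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# User-supplied options, e.g. read from the config file shown above.
torch_ddp_cfg = dict(find_unused_parameters=True, broadcast_buffers=False)

model = nn.Linear(16, 16).cuda()
# The extra keyword arguments reach DDP through plain dict unpacking.
model = DDP(model,
            device_ids=[torch.cuda.current_device()],
            **torch_ddp_cfg)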