diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index fd7a202b7..086efaac3 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -343,6 +343,9 @@ def initialize(model: nn.Module,
                                                      mode=amp_mode,
                                                      amp_config=cfg_)
 
+    # get torch ddp config
+    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -368,12 +371,16 @@ def initialize(model: nn.Module,
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
-                    device_ids=[torch.cuda.current_device()])
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                         ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.DATA),
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
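
Usage sketch (not part of the patch): with this change, a torch_ddp section in the user's
Colossal-AI config file is read via gpc.config.get('torch_ddp', dict()) and every key in it is
forwarded as a keyword argument to torch.nn.parallel.DistributedDataParallel. The keyword
values below are illustrative only.

    # config.py -- hypothetical user config
    # each key is passed straight through to torch.nn.parallel.DistributedDataParallel
    torch_ddp = dict(
        find_unused_parameters=True,   # tolerate parameters that receive no gradient in a step
        broadcast_buffers=False,       # skip re-broadcasting buffers at the start of each forward
    )

Because the lookup falls back to an empty dict, configs without a torch_ddp section keep the
previous behaviour.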