From 91a5999825137ffb4d575b21bf4c6cb41033161a Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Wed, 15 Jun 2022 18:11:53 +0800
Subject: [PATCH] [ddp] supported customized torch ddp configuration (#1123)

---
 colossalai/initialize.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index fd7a202b7..086efaac3 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -343,6 +343,9 @@ def initialize(model: nn.Module,
                                                      mode=amp_mode,
                                                      amp_config=cfg_)
 
+    # get torch ddp config
+    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -368,12 +371,16 @@ def initialize(model: nn.Module,
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
-                    device_ids=[torch.cuda.current_device()])
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                         ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.DATA),
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
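
Usage sketch (not part of the patch): with this change, keyword arguments for
torch.nn.parallel.DistributedDataParallel can be supplied through a 'torch_ddp'
entry in the ColossalAI config read via gpc.config. A minimal hypothetical user
config, assuming the standard torch DDP kwargs find_unused_parameters and
broadcast_buffers:

    # config.py (hypothetical user config passed to colossalai at launch time)
    # keys under torch_ddp are forwarded as **torch_ddp_cfg when initialize()
    # wraps the model with torch.nn.parallel.DistributedDataParallel
    torch_ddp = dict(
        find_unused_parameters=True,   # standard torch DDP keyword argument
        broadcast_buffers=False,       # standard torch DDP keyword argument
    )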