mirror of https://github.com/hpcaitech/ColossalAI
[ddp] supported customized torch ddp configuration (#1123)
parent fcf55777dd
commit 91a5999825
colossalai/initialize.py

@@ -343,6 +343,9 @@ def initialize(model: nn.Module,
                                                      mode=amp_mode,
                                                      amp_config=cfg_)
 
+    # get torch ddp config
+    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -368,12 +371,16 @@ def initialize(model: nn.Module,
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
-                    device_ids=[torch.cuda.current_device()])
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                         ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.DATA),
+                    device_ids=[torch.cuda.current_device()],
+                    **torch_ddp_cfg)
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
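For reference, the new 'torch_ddp' field is read from the global context config via gpc.config.get('torch_ddp', dict()) and unpacked into the torch.nn.parallel.DistributedDataParallel constructor. Below is a minimal sketch of a user config file that exercises this path; the keyword arguments are illustrative examples of valid DDP options, not defaults introduced by this commit:

# config.py -- hypothetical ColossalAI config file (configs are plain Python modules)
# Any keyword argument accepted by torch.nn.parallel.DistributedDataParallel
# may be listed under torch_ddp; the values below are illustrative only.
torch_ddp = dict(
    find_unused_parameters=True,   # tolerate parameters that receive no gradient
    broadcast_buffers=False,       # skip buffer broadcast on each forward pass
)

With such a config, the DDP wrapping in initialize() effectively becomes DDP(model, process_group=..., device_ids=[torch.cuda.current_device()], find_unused_parameters=True, broadcast_buffers=False), while an absent torch_ddp field falls back to an empty dict and leaves the previous behavior unchanged.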