diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index 8eecacb77..1e8884c86 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -70,8 +70,8 @@ class FP16Optimizer(Optimizer): # get process group def _get_process_group(parallel_mode): - if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA): - return gpc.get_group(ParallelMode.DATA) + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode): + return gpc.get_group(parallel_mode) else: return None