diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5837156a9..943e137e6 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,5 +1,6 @@
 import ctypes
 import random
+import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
                 tp_process_group=self.tp_group,
             )
         else:
-            assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+            if self.dp_size == 1:
+                warnings.warn(
+                    "Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
+                    "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
+                )
+
             assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
             optimizer = HybridParallelZeroOptimizer(
                 optimizer,
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 67b0bef50..d629e769d 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
 @parameterize(
     "test_args",
     [
+        {
+            "batch_size": 8,
+            "num_steps": 4,
+            "tp": 2,
+            "pp": 2,
+            "pp_style": "1f1b",
+            "num_model_chunks": 1,
+            "num_microbatches": 4,
+            "zero": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+            "max_length": 512,
+            "gradient_accumulation_step": 2,
+        },
         {
             "batch_size": 8,
             "num_steps": 4,
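
Note (not part of the patch): a minimal usage sketch of a configuration that exercises the new branch, assuming a 2-rank torchrun launch so that tp_size=2 and pp_size=1 leave dp_size == 1. The plugin arguments below mirror HybridParallelPlugin's public constructor at the time of this change; the model and optimizer are placeholders.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes a torchrun launch; initializes the distributed environment.
colossalai.launch_from_torch(config={})

# With 2 ranks, tp_size=2 and pp_size=1 give dp_size == 1, so zero_stage=1
# now emits the new warning instead of failing the removed assertion.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    precision="fp16",  # "fp32" would trip the assert guarding ZeRO
    initial_scale=1,
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)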