[fix] remove unnecessary dp_size assert (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning
pull/5357/head
Wenhao Chen 2024-02-02 14:40:20 +08:00 committed by GitHub
parent ffffc32dc7
commit 1c790c0877
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View File

@ -1,5 +1,6 @@
import ctypes
import random
import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,

View File

@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize(
"test_args",
[
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 8,
"num_steps": 4,