[fix] remove unnecessary dp_size assert (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning
pull/5357/head
Wenhao Chen 2024-02-02 14:40:20 +08:00 committed by GitHub
parent ffffc32dc7
commit 1c790c0877
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View File

@ -1,5 +1,6 @@
import ctypes import ctypes
import random import random
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from types import MethodType from types import MethodType
@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group, tp_process_group=self.tp_group,
) )
else: else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer( optimizer = HybridParallelZeroOptimizer(
optimizer, optimizer,

View File

@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize( @parameterize(
"test_args", "test_args",
[ [
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{ {
"batch_size": 8, "batch_size": 8,
"num_steps": 4, "num_steps": 4,