feat(configs/7B_sft.py): adapt to old version config

pull/293/head
huangting4201 2023-10-09 18:28:41 +08:00
parent edd7f9e8e1
commit 45c846f7df
3 changed files with 8 additions and 4 deletions

View File

@ -124,7 +124,7 @@ pipeline parallel: pipeline parallel size, only 1 is accepted currently.
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
"""
parallel = dict(
zero1=dict(size=8, fsdp=False),
zero1=8,
)
cudnn_deterministic = False

View File

@ -508,7 +508,7 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if self.config.parallel.zero1.get("fsdp", False):
if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False):
initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
if self.pipeline_parallel_size > 1:

View File

@ -65,10 +65,14 @@ def args_sanity_check():
# processing the parallel config in gpc
if "zero1" not in gpc.config.parallel:
gpc.config.parallel._add_item("zero1", -1)
gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False))
if isinstance(gpc.config.parallel.zero1, int):
zero1_size = gpc.config.parallel.zero1
gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False))
if "pipeline" not in gpc.config.parallel:
gpc.config.parallel._add_item("pipeline", 1)
gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False))
if "tensor" not in gpc.config.parallel:
gpc.config.parallel._add_item("tensor", 1)