From 45c846f7dfbd6a0d7459b4b5a7d2554b7abbb9e4 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 9 Oct 2023 18:28:41 +0800 Subject: [PATCH] feat(configs/7B_sft.py): adapt to old version config --- ci_scripts/train/ci_7B_sft.py | 2 +- internlm/core/context/parallel_context.py | 2 +- internlm/initialize/launch.py | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ci_scripts/train/ci_7B_sft.py b/ci_scripts/train/ci_7B_sft.py index 617ddb7..fea45e1 100644 --- a/ci_scripts/train/ci_7B_sft.py +++ b/ci_scripts/train/ci_7B_sft.py @@ -124,7 +124,7 @@ pipeline parallel: pipeline parallel size, only 1 is accepted currently. tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently. """ parallel = dict( - zero1=dict(size=8, fsdp=False), + zero1=8, ) cudnn_deterministic = False diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 031d6f7..b649726 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -508,7 +508,7 @@ class ParallelContext(metaclass=SingletonMeta): initializers.append(pgroup_initializer.Initializer_Model(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) - if self.config.parallel.zero1.get("fsdp", False): + if isinstance(self.config.parallel.zero1, dict) and self.config.parallel.zero1.get("fsdp", False): initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args)) if self.pipeline_parallel_size > 1: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 23596fd..214b51b 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -65,10 +65,14 @@ def args_sanity_check(): # procssing the parallel config in gpc if "zero1" not in gpc.config.parallel: - gpc.config.parallel._add_item("zero1", -1) + gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False)) + + if isinstance(gpc.config.parallel.zero1, int): + zero1_size = gpc.config.parallel.zero1 + gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False)) if "pipeline" not in gpc.config.parallel: - gpc.config.parallel._add_item("pipeline", 1) + gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False)) if "tensor" not in gpc.config.parallel: gpc.config.parallel._add_item("tensor", 1)