diff --git a/ci_scripts/train/ci_7B_sft.py b/ci_scripts/train/ci_7B_sft.py
index fea45e1..617ddb7 100644
--- a/ci_scripts/train/ci_7B_sft.py
+++ b/ci_scripts/train/ci_7B_sft.py
@@ -124,7 +124,7 @@ pipeline parallel: pipeline parallel size, only 1 is accepted currently.
 tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
 )
 
 cudnn_deterministic = False
diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 505d17b..865b959 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -154,7 +154,7 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=8,
+    zero1=dict(size=8, fsdp=False),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index f0a358a..031d6f7 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -436,10 +436,10 @@ class ParallelContext(metaclass=SingletonMeta):
         # if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights
         # because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank
         # pytorch vision: 1.13.1+cu117
-        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.get("use_fsdp", False):
+        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False):
             logger.warning(
                 f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
-                "will introduce redundancy when saving ckpts, recommend setting them to same value"
+                "will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value"
             )
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
@@ -508,7 +508,7 @@
         initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if self.config.parallel.get("use_fsdp", False):
+        if self.config.parallel.zero1.get("fsdp", False):
             initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
         if self.pipeline_parallel_size > 1:
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index e7d8932..23596fd 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -79,12 +79,14 @@
     pp = gpc.config.parallel.pipeline.size
 
     # check fsdp config
-    if "use_fsdp" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("use_fsdp", False)
+    if "fsdp" not in gpc.config.parallel.zero1:
+        gpc.config.parallel.zero1._add_item("fsdp", False)
+
     assert not (
-        gpc.config.parallel.use_fsdp and pp > 1
-    ), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
-    if gpc.config.parallel.use_fsdp:
+        gpc.config.parallel.zero1.fsdp and pp > 1
+    ), "FSDP is not supported when pipeline size > 1, please set pipeline size to 1 or disable FSDP"
+
+    if gpc.config.parallel.zero1.fsdp:
         assert (
             torch.__version__ >= "2.0.1"
         ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
@@ -288,7 +290,7 @@ def args_sanity_check():
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
     assert not (
-        gpc.config.model.num_experts > 1 and gpc.config.parallel.use_fsdp
+        gpc.config.model.num_experts > 1 and gpc.config.parallel.zero1.fsdp
     ), "FSDP does not support num_experts > 1"
 
     # process the parallel config
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index a3f21db..6000185 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -20,7 +20,7 @@ logger = get_logger(__file__)
 
 class FSDPadaptOptimizer(BaseOptimizer):
     """
-    optimizer for Pytorch FSDP if 'use_fsdp' is True in config file
+    optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file
     reserve some necessary components of hybird-optim:
     grad_scaler;
     grad_clip and unscale;
@@ -48,7 +48,6 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         # clip gradient
         self._clip_grad_norm = zero_cfg.clip_grad_norm
-        self.use_fsdp = gpc.config.parallel.use_fsdp
 
         # fp16 and fp32 params
         # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 53a5711..7af58dd 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -111,7 +111,7 @@ def initialize_model():
 
 
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -168,7 +168,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         eps=adam_cfg.adam_eps,
     )
 
-    if not gpc.config.parallel.use_fsdp:
+    if not gpc.config.parallel.zero1.fsdp:
         optimizer = HybridZeroOptimizer(
             naive_optimizer,
             grad_scal_cfg=gpc.config.grad_scaler,
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index d84dd62..0fc3718 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -265,7 +265,7 @@ def save_model_checkpoint(folder, model):
         model: The model to be saved
     """
 
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         states = get_shard_state_dict(model)
     else:
         states = model.state_dict()
@@ -285,18 +285,18 @@
     # even if pp is not considered, it will definitely not be written on the same machine.
     should_save_rank_pair = set()  # (tp_rank, dp_rank)
     for i in range(tp_size):
-        if gpc.config.parallel.use_fsdp:
+        if gpc.config.parallel.zero1.fsdp:
             for j in range(dp_size):
                 should_save_rank_pair.add((i, j))
         else:
             should_save_rank_pair.add((i, i % dp_size))
 
     if (tp_rank, dp_rank) in should_save_rank_pair:
-        f_dp = f"_dp{dp_rank}" if gpc.config.parallel.use_fsdp else ""
+        f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else ""
         fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt"
         fp = os.path.join(folder, fn)
         llm_save(fp, saved_obj=states)
-        if not gpc.config.parallel.use_fsdp or dp_rank == tp_rank % dp_size:
+        if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size:
             topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
             topo_fp = os.path.join(folder, topo_fn)
             llm_save(topo_fp, saved_obj=topo)
@@ -338,15 +338,15 @@
 
     # avoid ckpt misuse between FSDP and no-FSDP
     test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
-    assert ("_dp" in test_fn and gpc.config.parallel.use_fsdp) or (
-        "_dp" not in test_fn and not gpc.config.parallel.use_fsdp
+    assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
+        "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
     ), "FSDP model wants to load no-FSDP ckpts or reverse"
 
     max_pp, max_tp, max_zo = 0, 0, 0
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            if gpc.config.parallel.use_fsdp:
+            if gpc.config.parallel.zero1.fsdp:
                 max_zo = max(max_zo, int(segements[-1][2:]))
                 max_pp = max(max_pp, int(segements[-2][2:]))
                 max_tp = max(max_tp, int(segements[-3][2:]))
@@ -360,12 +360,12 @@
     assert (
         tp_size == max_tp + 1
     ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         assert (
             dp_size == max_zo + 1
         ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"
 
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
     else:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
@@ -388,7 +388,7 @@
 
     # try to load expert parameter to separate files if model have moe layer
     try_load_moe_checkpoint(folder, model, states, tp_rank, pp_rank)
-    if gpc.config.parallel.use_fsdp:
+    if gpc.config.parallel.zero1.fsdp:
         missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False)
     else:
         missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
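For reference, a minimal sketch (not part of the diff itself) of how the parallel section of configs/7B_sft.py looks after this change: zero1 is now a dict rather than an int, and FSDP is toggled through its fsdp key. Per the checks added in internlm/initialize/launch.py above, fsdp=True is only valid with torch >= 2.0.1, pipeline size 1, and num_experts == 1; fsdp=False keeps the previous zero1=8 behaviour.

    # updated config layout: zero1 is a dict and fsdp defaults to False
    parallel = dict(
        zero1=dict(size=8, fsdp=False),  # set fsdp=True to wrap the model with PyTorch FSDP
        tensor=1,
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=False,
    )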