diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 16c84a8..3a65c2f 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -427,10 +427,10 @@ class ParallelContext(metaclass=SingletonMeta):
         # if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights
         # because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank
         # pytorch vision: 1.13.1+cu117
-        if self.data_parallel_size > self.zero1_parallel_size:
+        if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.get("use_fsdp", False):
             logger.warning(
-                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, \
-                will introduce redundancy when saving ckpts, recommend setting them to same value"
+                f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, "
+                "will introduce redundancy when saving ckpts, recommend setting them to same value"
             )
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
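For context, a minimal standalone sketch (hypothetical values, a plain dict standing in for the real self.config.parallel object) of how the gated condition behaves after this change: the redundancy warning now fires only when use_fsdp is enabled and the zero1 group is smaller than the data-parallel group. The string change also matters: the old backslash continuation baked the next line's leading whitespace into the logged message, whereas the new implicit concatenation of adjacent literals does not.

# Hypothetical, self-contained sketch of the gated warning above.
# `config` only mimics the shape of self.config.parallel; the values are made up.
config = {"parallel": {"zero1": 2, "use_fsdp": True}}

data_parallel_size = 8
zero1_parallel_size = 2

# The warning is only relevant for FSDP sharded checkpoints, so it is skipped
# entirely when use_fsdp is absent or False.
if data_parallel_size > zero1_parallel_size and config["parallel"].get("use_fsdp", False):
    print(
        f"zo size: {zero1_parallel_size} < dp size: {data_parallel_size}, "
        "will introduce redundancy when saving ckpts, recommend setting them to same value"
    )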