diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 997bd46..915905a 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -329,7 +329,8 @@ class ParallelContext(metaclass=SingletonMeta): return self.is_last_rank(ParallelMode.PIPELINE) def is_no_pp_or_last_stage(self): - return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage() + # NOTICE!!!, this will ignore virutal stage + return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE) def get_world_size(self, parallel_mode: ParallelMode): """Returns the world size for `parallel_mode`.