From 456a6953f8be1a29304c1a4a197df773f80040b9 Mon Sep 17 00:00:00 2001 From: "877825076@qq.com" <877825076@qq.com> Date: Thu, 28 Dec 2023 19:59:44 +0800 Subject: [PATCH] fix: precompute max_seqlen for PipelineScheduler micro-batches --- internlm/core/scheduler/no_pipeline_scheduler.py | 2 ++ internlm/core/scheduler/pipeline_scheduler.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 8af8180..aa92788 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -82,6 +82,8 @@ class NonPipelineScheduler(BaseScheduler): _data.pop("indexes") if "cu_seqlens" in _data: + # Without BC modeling interface, we try to calculate 'max_seqlen' in advance + # to avoid overlap being interrupted by .item() operations. if isinstance(_data["cu_seqlens"], list): cu_seqlens = _data["cu_seqlens"][0] else: diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 5b864ff..9740975 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -222,6 +222,18 @@ class PipelineScheduler(BaseScheduler): micro_batch_data.pop("cu_seqlens") micro_batch_data.pop("indexes") + if "cu_seqlens" in micro_batch_data: + # Without BC modeling interface, we try to calculate 'max_seqlen' in advance + # to avoid overlap being interrupted by .item() operations. + if isinstance(micro_batch_data["cu_seqlens"], list): + cu_seqlens = micro_batch_data["cu_seqlens"][0] + else: + cu_seqlens = micro_batch_data["cu_seqlens"] + + cu_seqlens = cu_seqlens.squeeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + micro_batch_data.update({"max_seqlen": max_seqlen}) + micro_batch_data["label"] = micro_batch_label self.microbatch_offset += self.bsz_stride