diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 6e19425..dade907 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -122,6 +122,19 @@ class BaseScheduler(ABC): 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)" ) + def cal_max_seqlen(self, data: dict): + if isinstance(data, dict) and "cu_seqlens" in data: + # Without BC modeling interface, we try to calculate 'max_seqlen' in advance + # to avoid overlap being interrupted by .item() operations. + if isinstance(data["cu_seqlens"], list): + cu_seqlens = data["cu_seqlens"][0] + else: + cu_seqlens = data["cu_seqlens"] + + cu_seqlens = cu_seqlens.squeeze(0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + data.update({"max_seqlen": max_seqlen}) + class SchedulerHook(ABC): """ diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index aa92788..cd816e0 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -81,17 +81,7 @@ class NonPipelineScheduler(BaseScheduler): _data.pop("cu_seqlens") _data.pop("indexes") - if "cu_seqlens" in _data: - # Without BC modeling interface, we try to calculate 'max_seqlen' in advance - # to avoid overlap being interrupted by .item() operations. - if isinstance(_data["cu_seqlens"], list): - cu_seqlens = _data["cu_seqlens"][0] - else: - cu_seqlens = _data["cu_seqlens"] - - cu_seqlens = cu_seqlens.squeeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - _data.update({"max_seqlen": max_seqlen}) + self.cal_max_seqlen(_data) return _data, _label diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 9740975..7ea6a69 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -222,17 +222,7 @@ class PipelineScheduler(BaseScheduler): micro_batch_data.pop("cu_seqlens") micro_batch_data.pop("indexes") - if "cu_seqlens" in micro_batch_data: - # Without BC modeling interface, we try to calculate 'max_seqlen' in advance - # to avoid overlap being interrupted by .item() operations. - if isinstance(micro_batch_data["cu_seqlens"], list): - cu_seqlens = micro_batch_data["cu_seqlens"][0] - else: - cu_seqlens = micro_batch_data["cu_seqlens"] - - cu_seqlens = cu_seqlens.squeeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - micro_batch_data.update({"max_seqlen": max_seqlen}) + self.cal_max_seqlen(micro_batch_data) micro_batch_data["label"] = micro_batch_label self.microbatch_offset += self.bsz_stride