pull/564/head
877825076@qq.com 2023-12-28 19:59:44 +08:00
parent 83989b57ae
commit 456a6953f8
2 changed files with 14 additions and 0 deletions

View File

@ -82,6 +82,8 @@ class NonPipelineScheduler(BaseScheduler):
_data.pop("indexes") _data.pop("indexes")
if "cu_seqlens" in _data: if "cu_seqlens" in _data:
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
# to avoid overlap being interrupted by .item() operations.
if isinstance(_data["cu_seqlens"], list): if isinstance(_data["cu_seqlens"], list):
cu_seqlens = _data["cu_seqlens"][0] cu_seqlens = _data["cu_seqlens"][0]
else: else:

View File

@ -222,6 +222,18 @@ class PipelineScheduler(BaseScheduler):
micro_batch_data.pop("cu_seqlens") micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes") micro_batch_data.pop("indexes")
if "cu_seqlens" in micro_batch_data:
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
# to avoid overlap being interrupted by .item() operations.
if isinstance(micro_batch_data["cu_seqlens"], list):
cu_seqlens = micro_batch_data["cu_seqlens"][0]
else:
cu_seqlens = micro_batch_data["cu_seqlens"]
cu_seqlens = cu_seqlens.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
micro_batch_data.update({"max_seqlen": max_seqlen})
micro_batch_data["label"] = micro_batch_label micro_batch_data["label"] = micro_batch_label
self.microbatch_offset += self.bsz_stride self.microbatch_offset += self.bsz_stride