mirror of https://github.com/InternLM/InternLM
fix
parent
83989b57ae
commit
456a6953f8
|
@ -82,6 +82,8 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
_data.pop("indexes")
|
_data.pop("indexes")
|
||||||
|
|
||||||
if "cu_seqlens" in _data:
|
if "cu_seqlens" in _data:
|
||||||
|
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
|
||||||
|
# to avoid overlap being interrupted by .item() operations.
|
||||||
if isinstance(_data["cu_seqlens"], list):
|
if isinstance(_data["cu_seqlens"], list):
|
||||||
cu_seqlens = _data["cu_seqlens"][0]
|
cu_seqlens = _data["cu_seqlens"][0]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -222,6 +222,18 @@ class PipelineScheduler(BaseScheduler):
|
||||||
micro_batch_data.pop("cu_seqlens")
|
micro_batch_data.pop("cu_seqlens")
|
||||||
micro_batch_data.pop("indexes")
|
micro_batch_data.pop("indexes")
|
||||||
|
|
||||||
|
if "cu_seqlens" in micro_batch_data:
|
||||||
|
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
|
||||||
|
# to avoid overlap being interrupted by .item() operations.
|
||||||
|
if isinstance(micro_batch_data["cu_seqlens"], list):
|
||||||
|
cu_seqlens = micro_batch_data["cu_seqlens"][0]
|
||||||
|
else:
|
||||||
|
cu_seqlens = micro_batch_data["cu_seqlens"]
|
||||||
|
|
||||||
|
cu_seqlens = cu_seqlens.squeeze(0)
|
||||||
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
|
micro_batch_data.update({"max_seqlen": max_seqlen})
|
||||||
|
|
||||||
micro_batch_data["label"] = micro_batch_label
|
micro_batch_data["label"] = micro_batch_label
|
||||||
self.microbatch_offset += self.bsz_stride
|
self.microbatch_offset += self.bsz_stride
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue