mirror of https://github.com/InternLM/InternLM
fix
parent
456a6953f8
commit
c437ffbfc9
|
@ -122,6 +122,19 @@ class BaseScheduler(ABC):
|
|||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
||||
def cal_max_seqlen(self, data: dict):
    """Pre-compute the longest packed-sequence length and store it in ``data``.

    When ``data`` is a dict holding a ``"cu_seqlens"`` entry, the largest gap
    between adjacent entries of that value is written back into ``data``
    under the key ``"max_seqlen"``. Any other input is left untouched.

    Args:
        data: Batch dictionary, modified in place. ``data["cu_seqlens"]`` may
            be a single tensor-like object or a list of them; for a list only
            the first element is consulted.
            NOTE(review): presumably these are cumulative sequence offsets,
            optionally carrying a leading singleton dimension -- confirm
            against the data loader.

    Returns:
        None. The result is delivered through ``data["max_seqlen"]``.
    """
    # Nothing to compute unless we were handed a dict carrying "cu_seqlens".
    if not isinstance(data, dict) or "cu_seqlens" not in data:
        return

    # Without BC modeling interface, 'max_seqlen' is calculated in advance
    # so that overlap is not interrupted by .item() operations later on.
    raw = data["cu_seqlens"]
    boundaries = raw[0] if isinstance(raw, list) else raw

    # Drop a leading size-1 dimension, if there is one.
    boundaries = boundaries.squeeze(0)

    # Differences between adjacent entries; keep the largest as a Python scalar.
    gaps = boundaries[1:] - boundaries[:-1]
    longest = gaps.max().item()
    data.update({"max_seqlen": longest})
|
||||
|
||||
|
||||
class SchedulerHook(ABC):
|
||||
"""
|
||||
|
|
|
@ -81,17 +81,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
_data.pop("cu_seqlens")
|
||||
_data.pop("indexes")
|
||||
|
||||
if "cu_seqlens" in _data:
|
||||
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
|
||||
# to avoid overlap being interrupted by .item() operations.
|
||||
if isinstance(_data["cu_seqlens"], list):
|
||||
cu_seqlens = _data["cu_seqlens"][0]
|
||||
else:
|
||||
cu_seqlens = _data["cu_seqlens"]
|
||||
|
||||
cu_seqlens = cu_seqlens.squeeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
_data.update({"max_seqlen": max_seqlen})
|
||||
self.cal_max_seqlen(_data)
|
||||
|
||||
return _data, _label
|
||||
|
||||
|
|
|
@ -222,17 +222,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
micro_batch_data.pop("cu_seqlens")
|
||||
micro_batch_data.pop("indexes")
|
||||
|
||||
if "cu_seqlens" in micro_batch_data:
|
||||
# Without BC modeling interface, we try to calculate 'max_seqlen' in advance
|
||||
# to avoid overlap being interrupted by .item() operations.
|
||||
if isinstance(micro_batch_data["cu_seqlens"], list):
|
||||
cu_seqlens = micro_batch_data["cu_seqlens"][0]
|
||||
else:
|
||||
cu_seqlens = micro_batch_data["cu_seqlens"]
|
||||
|
||||
cu_seqlens = cu_seqlens.squeeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
micro_batch_data.update({"max_seqlen": max_seqlen})
|
||||
self.cal_max_seqlen(micro_batch_data)
|
||||
|
||||
micro_batch_data["label"] = micro_batch_label
|
||||
self.microbatch_offset += self.bsz_stride
|
||||
|
|
Loading…
Reference in New Issue