diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 79a6f62..8af8180 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -81,6 +81,16 @@ class NonPipelineScheduler(BaseScheduler):
             _data.pop("cu_seqlens")
             _data.pop("indexes")
 
+        if "cu_seqlens" in _data:
+            if isinstance(_data["cu_seqlens"], list):
+                cu_seqlens = _data["cu_seqlens"][0]
+            else:
+                cu_seqlens = _data["cu_seqlens"]
+
+            cu_seqlens = cu_seqlens.squeeze(0)
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            _data.update({"max_seqlen": max_seqlen})
+
         return _data, _label
 
     def _train_one_batch(
diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index d177053..751bce5 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -153,12 +153,8 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
@@ -189,8 +185,14 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
+    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not isinstance(indexes, int):
+            if seqlen is None:  # We try to avoid .item() calls in forward and backward.
+                seqlen = indexes.max().item() + 1
+        else:
+            seqlen = indexes + 1  # eval_forward
+
+        self._update_cos_sin_cache(qkv, seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
@@ -275,12 +277,8 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index a47a5cd..b1af14f 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -3,7 +3,7 @@
 
 import math
 from functools import wraps
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -380,7 +380,15 @@ class PackedFlashInternLm1D(nn.Module):
 
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self,
+        hidden_states=None,
+        cu_seqlens=None,
+        input_ids=None,
+        indexes=None,
+        inference_params=None,
+        max_seqlen: Optional[Union[int, None]] = None,
+    ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
@@ -401,7 +409,9 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None and max_seqlen is None:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index a20b61d..d162cc8 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -104,10 +104,10 @@ def get_batch_size(data):
         return data.size(0)
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return data[0][list(data[0].keys())[0]].size(0)
+            return data[0]["input_ids"].size(0)
         return data[0].size(0)
    elif isinstance(data, dict):
-        return data[list(data.keys())[0]].size(0)
+        return data["input_ids"].size(0)
 
 
 def check_data_is_packed(data):
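
For context, a minimal standalone sketch of the idea behind this patch: `max_seqlen` is derived from `cu_seqlens` once on the host, right after the micro-batch is assembled, and then passed through `forward`, so the model no longer has to call `.item()` (a GPU-to-CPU synchronization point) on every forward/backward pass. The helper name `precompute_max_seqlen` and the toy batch below are illustrative assumptions, not part of the repository.

```python
import torch


def precompute_max_seqlen(data: dict) -> dict:
    """Host-side helper mirroring the scheduler change: compute max_seqlen once per micro-batch."""
    if "cu_seqlens" in data:
        cu_seqlens = data["cu_seqlens"]
        if isinstance(cu_seqlens, list):
            cu_seqlens = cu_seqlens[0]
        cu_seqlens = cu_seqlens.squeeze(0)
        # The single .item() call happens here, before the model runs, so the
        # model forward can accept the precomputed value and skip its own sync.
        data["max_seqlen"] = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return data


# Toy usage: one sample packing two sequences of lengths 3 and 5.
batch = {
    "input_ids": torch.arange(8).unsqueeze(0),
    "cu_seqlens": torch.tensor([[0, 3, 8]]),
}
batch = precompute_max_seqlen(batch)
print(batch["max_seqlen"])  # -> 5
```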