mirror of https://github.com/InternLM/InternLM
feat: avoid calling item() in fwd/bwd
parent d418eba094
commit 70e84a21f8
@@ -81,6 +81,16 @@ class NonPipelineScheduler(BaseScheduler):
             _data.pop("cu_seqlens")
             _data.pop("indexes")
 
+        if "cu_seqlens" in _data:
+            if isinstance(_data["cu_seqlens"], list):
+                cu_seqlens = _data["cu_seqlens"][0]
+            else:
+                cu_seqlens = _data["cu_seqlens"]
+
+            cu_seqlens = cu_seqlens.squeeze(0)
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            _data.update({"max_seqlen": max_seqlen})
+
         return _data, _label
 
     def _train_one_batch(
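The hunk above moves the one unavoidable .item() call out of the model: .item() copies a scalar from device to host and forces a synchronization, so paying it once per batch in the scheduler is cheap, while paying it inside forward/backward stalls every step. A minimal sketch of the precomputation, with hypothetical cu_seqlens values:

    import torch

    # cu_seqlens holds the cumulative lengths of the sequences packed into
    # one row, e.g. lengths 5, 3, and 8 (hypothetical values).
    cu_seqlens = torch.tensor([0, 5, 8, 16])

    # Adjacent differences recover the per-sequence lengths; the single
    # .item() here is a host-device sync, paid once per batch.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    print(max_seqlen)  # 8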
@@ -153,12 +153,8 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
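After this change the cache update receives seqlen as a plain Python int from the caller instead of deriving it from the indexes tensor with .item(). A small sketch of what the old derivation cost, using hypothetical values:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical packed position ids. The old code recovered seqlen with
    # indexes.max().item(), which blocks the host until the device catches up.
    indexes = torch.arange(2048, device=device)
    seqlen_old = indexes.max().item() + 1  # one sync per cache check

    # The new code receives the same value as a plain int, computed once
    # upstream, so checking the cache touches no device memory.
    seqlen_new = 2048
    assert seqlen_old == seqlen_new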
@@ -189,8 +185,14 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
+    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not isinstance(indexes, int):
+            if seqlen is None:  # We try to avoid item() calls in fwd and bwd.
+                seqlen = indexes.max().item() + 1
+        else:
+            seqlen = indexes + 1  # eval_forward
+
+        self._update_cos_sin_cache(qkv, seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
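_forward keeps both entry points: a caller that already knows the sequence length passes seqlen and skips the sync entirely, while the old derivation remains as a fallback. A runnable sketch of just that dispatch branch, extracted here as a free function for illustration:

    import torch

    def resolve_seqlen(indexes, seqlen=None):
        # Mirrors the branch in _forward: tensor indexes with a known seqlen
        # avoid .item(); an int `indexes` is the eval_forward offset case.
        if not isinstance(indexes, int):
            if seqlen is None:
                seqlen = indexes.max().item() + 1  # fallback, syncs once
        else:
            seqlen = indexes + 1  # eval_forward
        return seqlen

    indexes = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])  # two packed sequences
    print(resolve_seqlen(indexes, seqlen=5))  # 5, no device sync
    print(resolve_seqlen(indexes))            # 5, via .item()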
@@ -275,12 +277,8 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
@@ -3,7 +3,7 @@
 
 import math
 from functools import wraps
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -380,7 +380,15 @@ class PackedFlashInternLm1D(nn.Module):
 
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self,
+        hidden_states=None,
+        cu_seqlens=None,
+        input_ids=None,
+        indexes=None,
+        inference_params=None,
+        max_seqlen: Optional[Union[int, None]] = None,
+    ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
@@ -401,7 +409,9 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None and max_seqlen is None:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
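Together with the scheduler change, forward now only recomputes max_seqlen when the caller did not supply it, so the steady-state training path carries no .item() inside the autograd graph. A hedged sketch of the intended call pattern; the batch dict and helper here are hypothetical:

    import torch

    def forward_kwargs(batch: dict) -> dict:
        # Hypothetical helper: passes along the precomputed max_seqlen when
        # the scheduler provided one; None triggers the in-model fallback.
        return {
            "input_ids": batch["input_ids"],
            "cu_seqlens": batch.get("cu_seqlens"),
            "indexes": batch.get("indexes"),
            "max_seqlen": batch.get("max_seqlen"),
        }

    # Toy packed batch: sequences of lengths 3 and 5 flattened into one row.
    batch = {
        "input_ids": torch.randint(0, 100, (1, 8)),
        "cu_seqlens": torch.tensor([0, 3, 8]),
        "indexes": torch.tensor([0, 1, 2, 0, 1, 2, 3, 4]),
        "max_seqlen": 5,
    }
    print(forward_kwargs(batch)["max_seqlen"])  # 5, skips the recompute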
@@ -104,10 +104,10 @@ def get_batch_size(data):
         return data.size(0)
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return data[0][list(data[0].keys())[0]].size(0)
+            return data[0]["input_ids"].size(0)
         return data[0].size(0)
     elif isinstance(data, dict):
-        return data[list(data.keys())[0]].size(0)
+        return data["input_ids"].size(0)
 
 
 def check_data_is_packed(data):
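Indexing the batch dict by "input_ids" fixes a latent ordering hazard: dicts iterate in insertion order, so the old first-key lookup could land on a tensor whose leading dimension is not the batch size. A small demonstration with hypothetical shapes:

    import torch

    data = {
        "cu_seqlens": torch.tensor([[0, 3, 8]]),            # inserted first
        "input_ids": torch.zeros(4, 8, dtype=torch.long),   # batch of 4
    }

    # Old lookup: whichever key was inserted first wins.
    old = data[list(data.keys())[0]].size(0)  # 1, from cu_seqlens -- wrong
    # New lookup: explicitly the batch dimension of input_ids.
    new = data["input_ids"].size(0)           # 4 -- correct
    print(old, new)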