feat: avoid calling item() in fwd/bwd

pull/564/head
877825076@qq.com 2023-12-28 19:12:02 +08:00
parent d418eba094
commit 70e84a21f8
4 changed files with 35 additions and 17 deletions

@@ -81,6 +81,16 @@ class NonPipelineScheduler(BaseScheduler):
             _data.pop("cu_seqlens")
             _data.pop("indexes")
 
+        if "cu_seqlens" in _data:
+            if isinstance(_data["cu_seqlens"], list):
+                cu_seqlens = _data["cu_seqlens"][0]
+            else:
+                cu_seqlens = _data["cu_seqlens"]
+            cu_seqlens = cu_seqlens.squeeze(0)
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            _data.update({"max_seqlen": max_seqlen})
+
         return _data, _label
 
     def _train_one_batch(
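
Note (not part of the diff): the reason for hoisting this computation is that .item() on a CUDA tensor blocks the host until the GPU catches up, so paying it once per micro-batch here is cheap, while paying it inside every layer's forward/backward is not. A minimal sketch of the idea, with made-up tensor values and a CPU fallback so it runs anywhere:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# cu_seqlens marks packed-sequence boundaries, e.g. three samples of length 5, 7 and 4.
cu_seqlens = torch.tensor([0, 5, 12, 16], device=device)

# One host-device sync here, while the micro-batch is being prepared ...
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

# ... and the plain Python int then travels with the batch, so fwd/bwd never sync again.
micro_batch = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen}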

@@ -153,12 +153,8 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
@@ -189,8 +185,14 @@
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
+    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not isinstance(indexes, int):
+            if seqlen is None:  # We try to avoid calling item() in fwd and bwd.
+                seqlen = indexes.max().item() + 1
+        else:
+            seqlen = indexes + 1  # eval_forward
+
+        self._update_cos_sin_cache(qkv, seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
@@ -275,12 +277,8 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
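
Note (not part of the diff): the rule that _forward now applies, extracted into a hypothetical helper (resolve_seqlen is not a function in the repo) so both paths are visible; training passes packed position ids plus the precomputed seqlen, while the eval path still passes a plain int:

import torch
from typing import Optional, Union

def resolve_seqlen(indexes: Union[int, torch.Tensor], seqlen: Optional[int] = None) -> int:
    # Mirrors the new branch in RotaryEmbedding._forward: only fall back to the
    # synchronizing .item() call when no precomputed seqlen was supplied.
    if not isinstance(indexes, int):
        if seqlen is None:
            seqlen = indexes.max().item() + 1
    else:
        seqlen = indexes + 1  # eval_forward: indexes is the current position
    return seqlen

indexes = torch.arange(16)                        # packed position ids
assert resolve_seqlen(indexes, seqlen=16) == 16   # fast path, no sync
assert resolve_seqlen(indexes) == 16              # legacy fallback, one sync
assert resolve_seqlen(7) == 8                     # eval_forward path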

@@ -3,7 +3,7 @@
 import math
 from functools import wraps
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -380,7 +380,15 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self,
+        hidden_states=None,
+        cu_seqlens=None,
+        input_ids=None,
+        indexes=None,
+        inference_params=None,
+        max_seqlen: Optional[Union[int, None]] = None,
+    ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
@@ -401,7 +409,9 @@
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None and max_seqlen is None:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
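
Note (not part of the diff): a sketch of how the keyword is expected to reach the model under this design, assuming the scheduler splats the micro-batch dict into forward as keyword arguments (the toy forward below is illustrative, not the real module); the None default keeps callers that never pass max_seqlen working, at the cost of the old per-call sync:

import torch

def forward(hidden_states=None, cu_seqlens=None, indexes=None, max_seqlen=None):
    # Recompute only when the caller did not supply a value (backward compatible,
    # but this fallback reintroduces the per-call .item() sync).
    if cu_seqlens is not None and max_seqlen is None:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return max_seqlen

data = {
    "hidden_states": torch.zeros(16, 8),
    "cu_seqlens": torch.tensor([0, 5, 12, 16]),
    "indexes": torch.arange(16),
    "max_seqlen": 7,                    # precomputed by the scheduler
}
assert forward(**data) == 7             # the model never calls .item()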

@@ -104,10 +104,10 @@ def get_batch_size(data):
         return data.size(0)
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return data[0][list(data[0].keys())[0]].size(0)
+            return data[0]["input_ids"].size(0)
         return data[0].size(0)
     elif isinstance(data, dict):
-        return data[list(data.keys())[0]].size(0)
+        return data["input_ids"].size(0)
 
 
 def check_data_is_packed(data):
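
Note (not part of the diff): a guess at why this last hunk stops relying on the first dict key — with the batch dict now also carrying entries that are not batch-shaped tensors (such as the integer "max_seqlen" added above), "whatever key comes first" is no longer a safe proxy for the input tensor:

import torch

data = {"max_seqlen": 7, "cu_seqlens": torch.tensor([0, 5, 12, 16]), "input_ids": torch.zeros(2, 16)}

# Old behaviour depended on dict ordering; here the first key is an int and would raise:
# data[list(data.keys())[0]].size(0)   # AttributeError: 'int' object has no attribute 'size'

# New behaviour always measures the batch dimension of the input ids.
batch_size = data["input_ids"].size(0)  # 2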