feat: avoid calling item() in fwd/bwd

pull/564/head
877825076@qq.com 2023-12-28 19:12:02 +08:00
parent d418eba094
commit 70e84a21f8
4 changed files with 35 additions and 17 deletions


@@ -81,6 +81,16 @@ class NonPipelineScheduler(BaseScheduler):
             _data.pop("cu_seqlens")
             _data.pop("indexes")
 
+        if "cu_seqlens" in _data:
+            if isinstance(_data["cu_seqlens"], list):
+                cu_seqlens = _data["cu_seqlens"][0]
+            else:
+                cu_seqlens = _data["cu_seqlens"]
+            cu_seqlens = cu_seqlens.squeeze(0)
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            _data.update({"max_seqlen": max_seqlen})
+
         return _data, _label
 
     def _train_one_batch(
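
The scheduler now computes max_seqlen once while the micro-batch is assembled, so the single remaining .item() call (a GPU-to-CPU sync) happens outside the model's forward/backward path and the result travels with the batch dict. A minimal sketch of that computation on a toy packed batch (the tensor values are illustrative, not taken from the repository):

    import torch

    # cu_seqlens marks packed-sequence boundaries: three sequences of
    # lengths 4, 2 and 6 packed into one row of 12 tokens.
    cu_seqlens = torch.tensor([0, 4, 6, 12])

    # Per-sequence lengths are differences of consecutive boundaries; the
    # largest one is the max_seqlen the varlen attention kernels expect.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    print(max_seqlen)  # 6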


@@ -153,12 +153,8 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
@@ -189,8 +185,14 @@ class RotaryEmbedding(torch.nn.Module):
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
+    def _forward(self, qkv: torch.Tensor, indexes=0, seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not isinstance(indexes, int):
+            if seqlen is None:  # We try to avoid calling item() in fwd and bwd.
+                seqlen = indexes.max().item() + 1
+        else:
+            seqlen = indexes + 1  # eval_forward
+        self._update_cos_sin_cache(qkv, seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
@@ -275,12 +277,8 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
         self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
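
Both rotary-embedding classes now take an already-computed seqlen in _update_cos_sin_cache, and _forward only falls back to indexes.max().item() when no seqlen was supplied. A standalone sketch mirroring that fallback logic (the helper name is made up for illustration):

    import torch

    def seqlen_from_indexes(indexes, seqlen=None):
        # Mirrors the new _forward logic: item() forces a GPU->CPU sync, so it
        # only runs when the caller did not already provide seqlen.
        if not isinstance(indexes, int):
            if seqlen is None:
                seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1  # eval_forward passes an int offset
        return seqlen

    indexes = torch.arange(12)               # position ids of a packed batch
    print(seqlen_from_indexes(indexes))      # 12, computed via item()
    print(seqlen_from_indexes(indexes, 12))  # 12, no device sync needed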


@@ -3,7 +3,7 @@
 import math
 from functools import wraps
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -380,7 +380,15 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self,
+        hidden_states=None,
+        cu_seqlens=None,
+        input_ids=None,
+        indexes=None,
+        inference_params=None,
+        max_seqlen: Optional[Union[int, None]] = None,
+    ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
@@ -401,7 +409,9 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None and max_seqlen is None:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
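
With the extra keyword argument, the model only recomputes max_seqlen from cu_seqlens as a fallback; when the scheduler has already attached it to the batch, forward runs without any item() call. A hypothetical micro-batch showing how the value would flow in (keys and shapes are illustrative):

    import torch

    _data = {
        "input_ids": torch.randint(0, 100, (1, 12)),
        "cu_seqlens": torch.tensor([0, 4, 6, 12]),
        "indexes": torch.arange(12),
        "max_seqlen": 6,  # attached by the scheduler in the first hunk
    }

    # model(**_data) would pass max_seqlen straight through; the
    # (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() fallback only runs
    # when "max_seqlen" is missing from the batch.
    print(_data.get("max_seqlen"))  # 6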


@@ -104,10 +104,10 @@ def get_batch_size(data):
         return data.size(0)
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return data[0][list(data[0].keys())[0]].size(0)
+            return data[0]["input_ids"].size(0)
         return data[0].size(0)
     elif isinstance(data, dict):
-        return data[list(data.keys())[0]].size(0)
+        return data["input_ids"].size(0)
 
 
 def check_data_is_packed(data):
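
Because the batch dict can now carry non-tensor entries such as max_seqlen (a plain int with no .size()), get_batch_size is pinned to the "input_ids" key instead of whichever key happens to come first. A small illustration with toy data:

    import torch

    data = {
        "input_ids": torch.randint(0, 100, (1, 12)),
        "cu_seqlens": torch.tensor([0, 4, 6, 12]),
        "max_seqlen": 6,  # an int: calling .size(0) on it would fail
    }

    # The old lookup used list(data.keys())[0], which depends on insertion
    # order; the new one asks for the batch dimension explicitly.
    print(data["input_ids"].size(0))  # 1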