pull/564/head
877825076@qq.com 2023-12-28 19:34:19 +08:00
parent 70e84a21f8
commit 06ececeb00
1 changed file with 9 additions and 10 deletions

@@ -3,7 +3,7 @@
 import math
 from functools import wraps
-from typing import Optional, Union
+from typing import Optional
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -381,13 +381,7 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output
 
     def forward(
-        self,
-        hidden_states=None,
-        cu_seqlens=None,
-        input_ids=None,
-        indexes=None,
-        inference_params=None,
-        max_seqlen: Optional[Union[int, None]] = None,
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
@@ -410,8 +404,13 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
 
-        if cu_seqlens is not None and max_seqlen is None:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
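
For context, a minimal usage sketch (not part of this commit) of how a caller can now supply max_seqlen as an ordinary keyword argument routed through **kwargs. The model call at the end is hypothetical; only the cu_seqlens arithmetic mirrors the fallback computed in the diff.

import torch

# Packed batch of three sequences with lengths 5, 3, and 8.
seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)])

# Same reduction forward() falls back to when "max_seqlen" is absent from kwargs.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # -> 8

# Hypothetical call with the new signature: max_seqlen travels through **kwargs.
# output = model(input_ids=input_ids, cu_seqlens=cu_seqlens, indexes=indexes, max_seqlen=max_seqlen)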