mirror of https://github.com/InternLM/InternLM
commit 06ececeb00 (parent 70e84a21f8): fix
@@ -3,7 +3,7 @@
 import math
 from functools import wraps
-from typing import Optional, Union
+from typing import Optional

 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
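The import change is fallout from the signature change below: Union is no longer referenced once the max_seqlen annotation is dropped. As an aside, the removed annotation was doubly wrapped; the following spellings all describe "int or None" (plain typing facts, not repository code):

from typing import Optional, Union

# All three annotations are equivalent and compare equal.
A = Optional[Union[int, None]]
B = Optional[int]
C = Union[int, None]
assert A == B == C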
@@ -381,13 +381,7 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output

     def forward(
-        self,
-        hidden_states=None,
-        cu_seqlens=None,
-        input_ids=None,
-        indexes=None,
-        inference_params=None,
-        max_seqlen: Optional[Union[int, None]] = None,
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
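In short, the explicit max_seqlen parameter (the only user of the Union import) is replaced by a catch-all **kwargs, and the value is read from kwargs only when the caller supplies it. A minimal standalone sketch of that pattern, using hypothetical names rather than repository code:

def forward(hidden_states=None, cu_seqlens=None, **kwargs):
    # Hypothetical reduced signature: max_seqlen can still be passed by keyword;
    # kwargs.pop consumes it so it is not forwarded further down the call chain.
    max_seqlen = kwargs.pop("max_seqlen", None)
    return hidden_states, cu_seqlens, max_seqlen

# Both calls are valid; the second simply leaves max_seqlen as None.
forward(hidden_states=None, cu_seqlens=None, max_seqlen=256)
forward(hidden_states=None, cu_seqlens=None)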
@@ -410,8 +404,13 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]

-        if cu_seqlens is not None and max_seqlen is None:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None

         for _, block in enumerate(self.blocks):
             hidden_states = block(
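Here cu_seqlens follows the cumulative-offset convention of flash-attention's variable-length kernels, so adjacent differences recover the per-sequence lengths and their maximum is the max_seqlen derived above. A small illustrative sketch with made-up values (not from the repository):

import torch

# Three packed sequences of lengths 4, 2 and 5 give cumulative offsets [0, 4, 6, 11].
cu_seqlens = torch.tensor([0, 4, 6, 11], dtype=torch.int32)

lengths = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([4, 2, 5], dtype=torch.int32)
max_seqlen = lengths.max().item()            # 5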