mirror of https://github.com/InternLM/InternLM
fix
parent 70e84a21f8
commit 06ececeb00
@@ -3,7 +3,7 @@
 import math
 from functools import wraps
-from typing import Optional, Union
+from typing import Optional

 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
@@ -381,13 +381,7 @@ class PackedFlashInternLm1D(nn.Module):
         self.parallel_output = parallel_output

     def forward(
-        self,
-        hidden_states=None,
-        cu_seqlens=None,
-        input_ids=None,
-        indexes=None,
-        inference_params=None,
-        max_seqlen: Optional[Union[int, None]] = None,
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
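The hunk above collapses the explicit parameter list into **kwargs: the removed max_seqlen argument (and with it the Union import) becomes just another keyword the function collects. A minimal standalone sketch of that effect, using a hypothetical forward_new function that is not part of the repository:

# Hypothetical stand-in for the new forward() signature: `max_seqlen` is no
# longer declared, so a call site that still passes it lands in **kwargs
# instead of raising a TypeError.
def forward_new(hidden_states=None, cu_seqlens=None, input_ids=None,
                indexes=None, inference_params=None, **kwargs):
    return sorted(kwargs)  # extra keywords are simply collected here

print(forward_new(hidden_states="h", max_seqlen=2048))  # -> ['max_seqlen']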
@@ -410,8 +404,13 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]

-        if cu_seqlens is not None and max_seqlen is None:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None

         for _, block in enumerate(self.blocks):
             hidden_states = block(
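The second hunk replaces the single derivation line with a three-way resolution of max_seqlen. A minimal sketch of that logic, pulled out of forward() into a hypothetical helper (resolve_max_seqlen and the tensor values below are illustrative, not code from the repository):

import torch

def resolve_max_seqlen(cu_seqlens, kwargs):
    if cu_seqlens is not None:
        if "max_seqlen" not in kwargs:
            # Derive the longest packed segment from the cumulative lengths.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        else:
            # A caller-supplied value takes precedence and is consumed from kwargs.
            max_seqlen = kwargs.pop("max_seqlen")
    else:
        max_seqlen = None
    return max_seqlen

cu_seqlens = torch.tensor([0, 5, 12, 20])                   # three packed sequences
print(resolve_max_seqlen(cu_seqlens, {}))                   # 8  (longest: 20 - 12)
print(resolve_max_seqlen(cu_seqlens, {"max_seqlen": 16}))   # 16
print(resolve_max_seqlen(None, {}))                         # None

The explicit else branch that pins max_seqlen to None is needed because, after this change, max_seqlen is no longer a parameter with a default value.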