pull/564/head
877825076@qq.com 2023-12-29 16:47:19 +08:00
parent 1d217eb94e
commit e9208728cb
2 changed files with 12 additions and 3 deletions

View File

@ -448,7 +448,9 @@ class PackedFlashInternLm1D(nn.Module):
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
):
# attention_mask: compute attention on the places where the value is 1
# old condition may fail when use shared embedding
if gpc.is_pipeline_first_stage():
@ -470,7 +472,14 @@ class PackedFlashInternLm1D(nn.Module):
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
if cu_seqlens is not None:
if "max_seqlen" not in kwargs:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
else:
max_seqlen = kwargs.pop("max_seqlen")
else:
max_seqlen = None
moe_losses = []
for _, block in enumerate(self.blocks):

View File

@ -33,7 +33,7 @@ class MlpModel(nn.Module):
self.embedding = embedding
def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
): # pylint: disable=W0613
if self.model_type != "torch" and self.part[0] != 0:
input_ids = hidden_states