mirror of https://github.com/InternLM/InternLM
fix
parent 1d217eb94e
commit e9208728cb
@@ -448,7 +448,9 @@ class PackedFlashInternLm1D(nn.Module):
                     setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
         if gpc.is_pipeline_first_stage():
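For context, a minimal sketch (illustrative values only, not taken from the repository) of the packed-input arguments this forward signature receives: cu_seqlens holds cumulative sequence boundaries and indexes holds per-token position IDs, from which the longest sequence length can be derived as in the hunk below.

import torch

# Two sequences of lengths 3 and 5 packed into one batch (illustrative values).
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
# Position IDs restart at 0 for each packed sequence (assumed convention,
# consistent with the comment that indexes carry the actual position IDs).
indexes = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])

# Longest sequence in the pack, computed the same way as in the next hunk.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(max_seqlen)  # 5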
@@ -470,7 +472,14 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         moe_losses = []
         for _, block in enumerate(self.blocks):
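The new branch lets a caller pass a precomputed max_seqlen through **kwargs instead of having it re-derived from cu_seqlens on every call. A standalone helper mirroring that branch (the helper itself is illustrative and not part of the repository):

import torch

def resolve_max_seqlen(cu_seqlens, **kwargs):
    # Same decision order as the added code: prefer a caller-supplied
    # max_seqlen, otherwise derive it from the cumulative lengths.
    if cu_seqlens is not None:
        if "max_seqlen" not in kwargs:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return kwargs.pop("max_seqlen")
    return None

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
print(resolve_max_seqlen(cu_seqlens))                # derived from cu_seqlens: 5
print(resolve_max_seqlen(cu_seqlens, max_seqlen=5))  # caller-supplied value wins: 5
print(resolve_max_seqlen(None))                      # no packing info: None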
@@ -33,7 +33,7 @@ class MlpModel(nn.Module):
         self.embedding = embedding
 
     def forward(
-        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
    ):  # pylint: disable=W0613
         if self.model_type != "torch" and self.part[0] != 0:
             input_ids = hidden_states
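The same **kwargs addition on the test-side MlpModel keeps it call-compatible when extra keywords such as max_seqlen are forwarded to it. A hypothetical stand-in illustrating the effect (TinyStub and the tensor values are assumptions, not code from the repository):

import torch
import torch.nn as nn

class TinyStub(nn.Module):
    # Hypothetical stand-in: **kwargs absorbs keywords the stub does not use,
    # so calls written against the full model's signature still succeed.
    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None,
                indexes=None, inference_params=None, **kwargs):
        return input_ids

stub = TinyStub()
tokens = torch.ones(1, 8, dtype=torch.long)
out = stub(input_ids=tokens, cu_seqlens=torch.tensor([0, 8]), max_seqlen=8)
print(out.shape)  # torch.Size([1, 8]); max_seqlen was absorbed by **kwargs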