mirror of https://github.com/InternLM/InternLM
fix
parent 1d217eb94e
commit e9208728cb
@@ -448,7 +448,9 @@ class PackedFlashInternLm1D(nn.Module):
                 setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
         if gpc.is_pipeline_first_stage():
@@ -470,7 +472,14 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         moe_losses = []
         for _, block in enumerate(self.blocks):
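
The hunk above resolves max_seqlen in three cases: a value supplied by the caller through **kwargs wins, otherwise it is derived from cu_seqlens, and it stays None for unpacked input. A minimal standalone sketch of that logic, assuming InternLM's packed cu_seqlens convention (the helper name resolve_max_seqlen is ours, not part of the commit):

import torch

def resolve_max_seqlen(cu_seqlens, **kwargs):
    # Hypothetical helper mirroring the commit's control flow.
    if cu_seqlens is not None:
        if "max_seqlen" not in kwargs:
            # Derive the longest packed sequence from the cumulative offsets.
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        # The caller already knows the value; skip the reduction.
        return kwargs.pop("max_seqlen")
    return None

# cu_seqlens packing three sequences of lengths 3, 5, and 2:
cu = torch.tensor([0, 3, 8, 10])
assert resolve_max_seqlen(cu) == 5
assert resolve_max_seqlen(cu, max_seqlen=5) == 5
assert resolve_max_seqlen(None) is None
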
@@ -33,7 +33,7 @@ class MlpModel(nn.Module):
         self.embedding = embedding
 
     def forward(
-        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):  # pylint: disable=W0613
         if self.model_type != "torch" and self.part[0] != 0:
             input_ids = hidden_states
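
Since both forward signatures now accept **kwargs, a caller can pass a precomputed max_seqlen once instead of letting the model re-derive it from cu_seqlens. A hedged usage sketch; model, hidden_states, cu_seqlens, and indexes are placeholders, not values from this commit:

# Illustrative only: assumes `model` is a patched PackedFlashInternLm1D
# and the batch is already packed.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = model(
    hidden_states=hidden_states,
    cu_seqlens=cu_seqlens,
    indexes=indexes,
    max_seqlen=max_seqlen,  # consumed from **kwargs by the patched forward
)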