[fix] fix test zerobubble

pull/6083/head
duanjunwen 4 weeks ago
parent 6377aa0fff
commit 5aee4261a6

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2]
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:

Loading…
Cancel
Save