[fix] fix test zerobubble

pull/6083/head
duanjunwen 4 weeks ago
parent 6377aa0fff
commit 5aee4261a6

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2] batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2] batch_size, seq_length = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None: if inputs_embeds is None:

Loading…
Cancel
Save