[fix] fix test_shard_llama ci;

pull/6083/head
duanjunwen 2024-10-28 02:42:33 +00:00
parent 03fa79a55c
commit 6377aa0fff
2 changed files with 1 addition and 2 deletions

@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         if inputs_embeds is None:
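Context on the fix: `inputs_embeds` is a 3-D tensor of shape (batch_size, seq_length, hidden_size), so `inputs_embeds.shape[:2]` yields exactly two values; unpacking it into three names raises a ValueError at runtime, which is consistent with the CI failure this commit addresses. A minimal sketch of the failure mode, with tensor sizes that are illustrative assumptions rather than values from the test suite:

    import torch

    # Illustrative sizes (assumed, not taken from the test suite)
    inputs_embeds = torch.randn(2, 16, 128)  # (batch_size, seq_length, hidden_size)

    # Removed line: .shape[:2] has exactly two entries, so a three-way unpack
    # fails with "ValueError: not enough values to unpack (expected 3, got 2)".
    # batch_size, seq_length, _ = inputs_embeds.shape[:2]

    # Fixed line: two targets match the two sliced dimensions.
    batch_size, seq_length = inputs_embeds.shape[:2]
    print(batch_size, seq_length)  # 2 16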

@@ -325,7 +325,6 @@ def run_llama_test(test_config):
         ).get_v_schedule()
         test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        print(f"name {name}")
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
             continue
         try:
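The deleted `print(f"name {name}")` appears to be leftover debug output; removing it quiets the per-model loop in the test logs without changing test behavior.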