[fix] fix llama modeling policy;

pull/6083/head
duanjunwen 2024-10-25 10:17:06 +00:00
parent cc0dfddcbc
commit 03fa79a55c
2 changed files with 3 additions and 1 deletion


@@ -96,7 +96,8 @@ class LlamaPolicy(Policy):
                 target_key=attn_cls,
             )
-            if self.pipeline_stage_manager is not None:
+            # if self.pipeline_stage_manager is not None:
+            if self.pipeline_stage_manager is None:
                 self.append_or_create_method_replacement(
                     description={
                         "forward": partial(


@@ -325,6 +325,7 @@ def run_llama_test(test_config):
     ).get_v_schedule()
     test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        print(f"name {name}")
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
             continue
         try:
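
The extra print in this hunk looks like a temporary aid for seeing which sub-model the test loop is on. If it is meant to stay, one possible alternative (an assumption, not part of this commit) is to route the message through Python's standard logging module so it can be silenced; "some_llama_variant" below is a placeholder for a sub_model_zoo entry name.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

for name in ["some_llama_variant"]:
    logger.info("running sub-model %s", name)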