From 2e4cbe3a2d1dc60cfe44b99ce5a5b68433fe6b97 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 20 Aug 2024 09:11:02 +0000
Subject: [PATCH] fix

---
 .../test_shardformer/test_model/test_shard_llama.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index a91ffd00d..3c66f6097 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -153,8 +153,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        { # Test ring + Flash attention
-            "tp_size": 2,
+        # Double Ring Attention
+        {
+            "tp_size": 1,
             "pp_size": 1,
             "sp_size": 4,
             "num_microbatches": 1,
@@ -166,19 +167,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
             "inner_ring_size": 2,
         },
-        { # Ulysess + Flash attention
+        # Ring Attention + PP
+        {
             "tp_size": 1,
             "pp_size": 2,
             "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
+            "sequence_parallelism_mode": "ring_attn",
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        # Ring Attention + TP
         {
             "tp_size": 2,
             "pp_size": 1,
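
Note (not part of the patch): a minimal sketch of how a test_config entry such as
the new "Ring Attention + PP" case above is consumed, assuming ColossalAI's
HybridParallelPlugin API; the Booster wiring below is illustrative and not the
test's exact helper code.

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # The "Ring Attention + PP" config from the diff above.
    test_config = {
        "tp_size": 1,
        "pp_size": 2,
        "sp_size": 2,
        "num_microbatches": 2,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "ring_attn",
        "use_lazy_init": True,
        "zero_stage": 1,
        "precision": "fp16",
        "initial_scale": 1,
    }

    # use_lazy_init controls model materialization in the test helper,
    # not the plugin itself, so it is popped before construction.
    use_lazy_init = test_config.pop("use_lazy_init")

    # Assumes the distributed environment was already initialized
    # (e.g. via colossalai.launch_from_torch) with
    # tp_size * pp_size * sp_size ranks available.
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)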