Browse Source

fix

pull/6023/head
wangbluo 3 months ago
parent
commit
2e4cbe3a2d
  1. 12
      tests/test_shardformer/test_model/test_shard_llama.py

12
tests/test_shardformer/test_model/test_shard_llama.py

@@ -153,8 +153,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{ # Test ring + Flash attention
"tp_size": 2,
# Double Ring Attention
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"num_microbatches": 1,
@@ -166,19 +167,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"initial_scale": 1,
"inner_ring_size": 2,
},
{ # Ulysses + Flash attention
# Ring Attention + PP
{
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
# Ring Attention + TP
{
"tp_size": 2,
"pp_size": 1,

Loading…
Cancel
Save