
fix

pull/6023/head
wangbluo committed 3 months ago
commit 2e4cbe3a2d
1 changed file with 7 additions and 5 deletions:
tests/test_shardformer/test_model/test_shard_llama.py

@@ -153,8 +153,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {  # Test ring + Flash attention
-            "tp_size": 2,
+        # Double Ring Attention
+        {
+            "tp_size": 1,
             "pp_size": 1,
             "sp_size": 4,
             "num_microbatches": 1,
@@ -166,19 +167,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
             "inner_ring_size": 2,
         },
-        {  # Ulysess + Flash attention
+        # Ring Attention + PP
+        {
             "tp_size": 1,
             "pp_size": 2,
             "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
-            "enable_flash_attention": True,
+            "sequence_parallelism_mode": "ring_attn",
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        # Ring Attention + TP
         {
             "tp_size": 2,
             "pp_size": 1,
