|
|
@ -153,8 +153,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, |
|
|
|
@parameterize( |
|
|
|
@parameterize( |
|
|
|
"test_config", |
|
|
|
"test_config", |
|
|
|
[ |
|
|
|
[ |
|
|
|
{ # Test ring + Flash attention |
|
|
|
# Double Ring Attention |
|
|
|
"tp_size": 2, |
|
|
|
{ |
|
|
|
|
|
|
|
"tp_size": 1, |
|
|
|
"pp_size": 1, |
|
|
|
"pp_size": 1, |
|
|
|
"sp_size": 4, |
|
|
|
"sp_size": 4, |
|
|
|
"num_microbatches": 1, |
|
|
|
"num_microbatches": 1, |
|
|
@ -166,19 +167,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, |
|
|
|
"initial_scale": 1, |
|
|
|
"initial_scale": 1, |
|
|
|
"inner_ring_size": 2, |
|
|
|
"inner_ring_size": 2, |
|
|
|
}, |
|
|
|
}, |
|
|
|
{ # Ulysses + Flash attention |
|
|
|
# Ring Attention + PP |
|
|
|
|
|
|
|
{ |
|
|
|
"tp_size": 1, |
|
|
|
"tp_size": 1, |
|
|
|
"pp_size": 2, |
|
|
|
"pp_size": 2, |
|
|
|
"sp_size": 2, |
|
|
|
"sp_size": 2, |
|
|
|
"num_microbatches": 2, |
|
|
|
"num_microbatches": 2, |
|
|
|
"enable_sequence_parallelism": True, |
|
|
|
"enable_sequence_parallelism": True, |
|
|
|
"sequence_parallelism_mode": "all_to_all", |
|
|
|
"sequence_parallelism_mode": "ring_attn", |
|
|
|
"enable_flash_attention": True, |
|
|
|
|
|
|
|
"use_lazy_init": True, |
|
|
|
"use_lazy_init": True, |
|
|
|
"zero_stage": 1, |
|
|
|
"zero_stage": 1, |
|
|
|
"precision": "fp16", |
|
|
|
"precision": "fp16", |
|
|
|
"initial_scale": 1, |
|
|
|
"initial_scale": 1, |
|
|
|
}, |
|
|
|
}, |
|
|
|
|
|
|
|
# Ring Attention + TP |
|
|
|
{ |
|
|
|
{ |
|
|
|
"tp_size": 2, |
|
|
|
"tp_size": 2, |
|
|
|
"pp_size": 1, |
|
|
|
"pp_size": 1, |
|
|
|