[hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

pull/5278/head^2
Zhongkai Zhao 2024-01-17 17:42:29 +08:00 committed by GitHub
parent 46e091651b
commit 5d9a0ae75b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -154,7 +154,7 @@ def run_forward_backward_with_hybrid_plugin(
data = data_gen_fn()
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data["input_ids"].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len