diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 85be9a242..f5fc21b4c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -181,6 +181,7 @@ def run_forward_backward_with_hybrid_plugin( data_iter = iter([data]) sharded_output = booster.execute_pipeline( + data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True data_iter, sharded_model, _criterion,