From f60162b2657a18af1468bd172835828787d23c17 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 4 Jul 2023 14:35:55 +0800 Subject: [PATCH] [shardformer] added tests --- tests/test_shardformer/test_model/test_shard_vit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 2b02c83e0..c1126cb2c 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -56,6 +56,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('enable_tensor_parallelism', [True, False]) def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + print(sub_model_zoo) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)