[shardformer] added tests

pull/4445/head
klhhhhh 2023-07-04 14:35:55 +08:00 committed by Hongxin Liu
parent ed34bb1310
commit f60162b265
1 changed files with 1 additions and 0 deletions

View File

@ -56,6 +56,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
print(sub_model_zoo)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)