[shardformer] vit test finish and support

pull/4445/head
klhhhhh 2023-07-06 10:59:42 +08:00 committed by Hongxin Liu
parent f60162b265
commit c49286985d
1 changed files with 0 additions and 1 deletions

View File

@ -56,7 +56,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
print(sub_model_zoo)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)