mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] vit test finish and support
parent
f60162b265
commit
c49286985d
|
@ -56,7 +56,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||||
@parameterize('enable_tensor_parallelism', [True, False])
|
@parameterize('enable_tensor_parallelism', [True, False])
|
||||||
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
|
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||||
print(sub_model_zoo)
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||||
|
|
Loading…
Reference in New Issue