[booster] removed models that don't support fsdp (#3744)

Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
pull/3748/head
wukong1992 2023-05-15 19:35:21 +08:00 committed by GitHub
parent afb239bbf8
commit 6050f37776
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 7 deletions

View File

@@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
def check_torch_fsdp_plugin(): def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if 'diffusers' in name: if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'
]):
continue continue
run_fn(model_fn, data_gen_fn, output_transform_fn) run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -58,12 +61,6 @@ def run_dist(rank, world_size, port):
check_torch_fsdp_plugin() check_torch_fsdp_plugin()
# FIXME: this test is not working
@pytest.mark.skip(
"ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE"
)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_torch_fsdp_plugin(): def test_torch_fsdp_plugin():