mirror of https://github.com/hpcaitech/ColossalAI
[booster] removed models that don't support fsdp (#3744)
Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>pull/3748/head
parent
afb239bbf8
commit
6050f37776
|
@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
||||||
|
|
||||||
def check_torch_fsdp_plugin():
|
def check_torch_fsdp_plugin():
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||||
if 'diffusers' in name:
|
if any(element in name for element in [
|
||||||
|
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
|
||||||
|
'torchvision_inception_v3'
|
||||||
|
]):
|
||||||
continue
|
continue
|
||||||
run_fn(model_fn, data_gen_fn, output_transform_fn)
|
run_fn(model_fn, data_gen_fn, output_transform_fn)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -58,12 +61,6 @@ def run_dist(rank, world_size, port):
|
||||||
check_torch_fsdp_plugin()
|
check_torch_fsdp_plugin()
|
||||||
|
|
||||||
|
|
||||||
# FIXME: this test is not working
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
"ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE"
|
|
||||||
)
|
|
||||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
|
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_torch_fsdp_plugin():
|
def test_torch_fsdp_plugin():
|
||||||
|
|
Loading…
Reference in New Issue