mirror of https://github.com/hpcaitech/ColossalAI
[booster] removed models that don't support fsdp (#3744)
Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>pull/3748/head
parent
afb239bbf8
commit
6050f37776
|
@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
|
|||
|
||||
def check_torch_fsdp_plugin():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||
if 'diffusers' in name:
|
||||
if any(element in name for element in [
|
||||
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
|
||||
'torchvision_inception_v3'
|
||||
]):
|
||||
continue
|
||||
run_fn(model_fn, data_gen_fn, output_transform_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -58,12 +61,6 @@ def run_dist(rank, world_size, port):
|
|||
check_torch_fsdp_plugin()
|
||||
|
||||
|
||||
# FIXME: this test is not working
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
"ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE"
|
||||
)
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_fsdp_plugin():
|
||||
|
|
Loading…
Reference in New Issue