[ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish
pull/5256/head
Frank Lee 11 months ago committed by GitHub
parent d5eeeb1416
commit 2b83418719
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
def get_sub_registry(self, keyword: Union[str, List[str]]):
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
"""
Get a sub registry with models that contain the keyword.
@ -76,10 +76,24 @@ class ModelZooRegistry(dict):
keyword_list = keyword
assert isinstance(keyword_list, (list, tuple))
if exclude is None:
exclude_keywords = []
elif isinstance(exclude, str):
exclude_keywords = [exclude]
else:
exclude_keywords = exclude
assert isinstance(exclude_keywords, (list, tuple))
for k, v in self.items():
for kw in keyword_list:
if kw in k:
new_dict[k] = v
should_exclude = False
for ex_kw in exclude_keywords:
if ex_kw in k:
should_exclude = True
if not should_exclude:
new_dict[k] = v
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
return new_dict

@ -16,7 +16,7 @@ from tests.kit.model_zoo import model_zoo
@parameterize("lazy_init", [True, False])
def check_shardformer_with_ddp(lazy_init: bool):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
# create shardformer
# ranks: [0, 1, 2, 3]

Loading…
Cancel
Save