diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 44a0adc6a..5e8e0b382 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -61,7 +61,7 @@ class ModelZooRegistry(dict): """ self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) - def get_sub_registry(self, keyword: Union[str, List[str]]): + def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None): """ Get a sub registry with models that contain the keyword. @@ -76,10 +76,24 @@ class ModelZooRegistry(dict): keyword_list = keyword assert isinstance(keyword_list, (list, tuple)) + if exclude is None: + exclude_keywords = [] + elif isinstance(exclude, str): + exclude_keywords = [exclude] + else: + exclude_keywords = exclude + assert isinstance(exclude_keywords, (list, tuple)) + for k, v in self.items(): for kw in keyword_list: if kw in k: - new_dict[k] = v + should_exclude = False + for ex_kw in exclude_keywords: + if ex_kw in k: + should_exclude = True + + if not should_exclude: + new_dict[k] = v assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index f642a9dca..4b741c21b 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -16,7 +16,7 @@ from tests.kit.model_zoo import model_zoo @parameterize("lazy_init", [True, False]) def check_shardformer_with_ddp(lazy_init: bool): - sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") # create shardformer # ranks: [0, 1, 2, 3]