mirror of https://github.com/hpcaitech/ColossalAI
parent
d5eeeb1416
commit
2b83418719
|
@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
|
||||||
"""
|
"""
|
||||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
|
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
|
||||||
|
|
||||||
def get_sub_registry(self, keyword: Union[str, List[str]]):
|
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
|
||||||
"""
|
"""
|
||||||
Get a sub registry with models that contain the keyword.
|
Get a sub registry with models that contain the keyword.
|
||||||
|
|
||||||
|
@ -76,10 +76,24 @@ class ModelZooRegistry(dict):
|
||||||
keyword_list = keyword
|
keyword_list = keyword
|
||||||
assert isinstance(keyword_list, (list, tuple))
|
assert isinstance(keyword_list, (list, tuple))
|
||||||
|
|
||||||
|
if exclude is None:
|
||||||
|
exclude_keywords = []
|
||||||
|
elif isinstance(exclude, str):
|
||||||
|
exclude_keywords = [exclude]
|
||||||
|
else:
|
||||||
|
exclude_keywords = exclude
|
||||||
|
assert isinstance(exclude_keywords, (list, tuple))
|
||||||
|
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
for kw in keyword_list:
|
for kw in keyword_list:
|
||||||
if kw in k:
|
if kw in k:
|
||||||
new_dict[k] = v
|
should_exclude = False
|
||||||
|
for ex_kw in exclude_keywords:
|
||||||
|
if ex_kw in k:
|
||||||
|
should_exclude = True
|
||||||
|
|
||||||
|
if not should_exclude:
|
||||||
|
new_dict[k] = v
|
||||||
|
|
||||||
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
|
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
|
||||||
return new_dict
|
return new_dict
|
||||||
|
|
|
@ -16,7 +16,7 @@ from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
@parameterize("lazy_init", [True, False])
|
@parameterize("lazy_init", [True, False])
|
||||||
def check_shardformer_with_ddp(lazy_init: bool):
|
def check_shardformer_with_ddp(lazy_init: bool):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
|
||||||
|
|
||||||
# create shardformer
|
# create shardformer
|
||||||
# ranks: [0, 1, 2, 3]
|
# ranks: [0, 1, 2, 3]
|
||||||
|
|
Loading…
Reference in New Issue