2022-03-08 02:19:18 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
|
|
|
|
|
|
class Registry:
    """A name -> callable registry.

    Callables are added with the :meth:`register` decorator and looked up by
    name with :meth:`get_callable`. Iterating a registry yields the registered
    callables in registration order.
    """

    def __init__(self):
        # Maps registration name -> registered callable. Dicts preserve
        # insertion order, which is the order iteration yields.
        self._registry = dict()

    def register(self, name):
        """Return a decorator that registers its argument under ``name``.

        Args:
            name: Key to store the decorated callable under.

        Returns:
            A decorator that stores the callable in the registry and returns
            it unchanged, so the decorated name still refers to the callable.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently let a duplicate overwrite an entry.
        if name in self._registry:
            raise ValueError(f"{name!r} is already registered")

        def _register(callable_):
            self._registry[name] = callable_
            # Hand the callable back so decoration doesn't rebind its name to None.
            return callable_

        return _register

    def get_callable(self, name: str):
        """Return the callable registered under ``name``.

        Raises:
            KeyError: If nothing is registered under ``name``.
        """
        return self._registry[name]

    def __iter__(self):
        """Return an iterator over a snapshot of the registered callables.

        Each call returns an independent iterator, so nested loops over the
        same registry (or loops interleaved via ``zip``) don't interfere.
        """
        # Legacy shared cursor, kept so code that calls ``next(registry)``
        # after ``iter(registry)`` behaves as before.
        self._names = list(self._registry.keys())
        self._idx = 0
        self._len = len(self._names)
        return iter(list(self._registry.values()))

    def __next__(self):
        """Advance the legacy shared cursor initialised by :meth:`__iter__`.

        Prefer ``for``/``iter()``, whose iterators do not share state.

        Raises:
            StopIteration: When the cursor has passed the last callable.
        """
        if self._idx >= self._len:
            raise StopIteration
        callable_ = self._registry[self._names[self._idx]]
        self._idx += 1
        return callable_
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: the two shared registries below.
__all__ = ['non_distributed_component_funcs', 'model_parallel_component_funcs']

# Registry for components meant to run without distribution.
non_distributed_component_funcs = Registry()
# Registry for components meant to run with model parallelism.
model_parallel_component_funcs = Registry()
|