mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
926 B
38 lines
926 B
#!/usr/bin/env python |
|
|
|
|
|
class Registry:
    """Map names to callables, populated through a decorator.

    Typical use::

        registry = Registry()

        @registry.register("my_component")
        def build():
            ...

        registry.get_callable("my_component")  # -> build
        for fn in registry:                    # yields callables in
            fn()                               # registration order
    """

    def __init__(self):
        self._registry = dict()
        # Cursor state used only by the legacy ``next(registry)`` protocol;
        # initialised here so ``__next__`` never hits a missing attribute.
        self._names = []
        self._idx = 0
        self._len = 0

    def register(self, name):
        """Return a decorator that stores the decorated callable under *name*.

        Args:
            name: Key under which the callable will be stored.

        Returns:
            A decorator that records its argument in the registry and
            returns that same argument unchanged.

        Raises:
            AssertionError: If *name* is already present in the registry.
        """
        # Raised explicitly instead of via ``assert`` so the duplicate check
        # is not stripped when Python runs with ``-O``; the exception type is
        # the one the original ``assert`` produced.
        if name in self._registry:
            raise AssertionError("{!r} is already registered".format(name))

        def _register(callable_):
            self._registry[name] = callable_
            # Hand the callable back; otherwise the decorated name would be
            # rebound to None in the module that defines it.
            return callable_

        return _register

    def get_callable(self, name: str):
        """Return the callable stored under *name*.

        Raises:
            KeyError: If nothing is registered under *name*.
        """
        return self._registry[name]

    def __iter__(self):
        """Return a fresh iterator over the registered callables.

        The iterator walks a snapshot of the registry in insertion order, so
        nested or overlapping loops over the same registry do not disturb one
        another. The internal cursor is also reset so that code calling
        ``next(registry)`` after ``iter(registry)`` keeps working.
        """
        self._names = list(self._registry)
        self._idx = 0
        self._len = len(self._names)
        return iter([self._registry[key] for key in self._names])

    def __next__(self):
        """Advance the internal cursor set up by the last ``iter(registry)``.

        Raises:
            StopIteration: When the cursor has passed the last entry.
        """
        if self._idx >= self._len:
            raise StopIteration
        callable_ = self._registry[self._names[self._idx]]
        self._idx += 1
        return callable_
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    "non_distributed_component_funcs",
    "model_parallel_component_funcs",
]

# Module-level registries shared by importers; each starts out empty and is
# filled through ``Registry.register``.
non_distributed_component_funcs = Registry()
model_parallel_component_funcs = Registry()
|
|
|