mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
9.5 KiB
234 lines
9.5 KiB
#!/usr/bin/env python |
|
# -*- encoding: utf-8 -*- |
|
|
|
import inspect |
|
from collections.abc import Iterable |
|
|
|
from colossalai.registry import * |
|
|
|
|
|
def build_from_config(module, config: dict):
    """Returns an object of :class:`module` constructed from `config`.

    Every item of `config` is forwarded to `module` as a keyword argument.

    :param module: A python or user-defined class
    :type module: class
    :param config: A python dict containing information used in the construction
        of the return object
    :type config: dict
    :raises AssertionError: Raises an AssertionError if `module` is not a class
    :return: An object of interest
    :rtype: Object
    """
    # Raise explicitly instead of using an ``assert`` statement so that the
    # documented check still runs when Python is started with ``-O``.
    if not inspect.isclass(module):
        raise AssertionError('module must be a class')
    return module(**config)
|
|
|
|
|
def build_from_registry(config, registry: Registry):
    """Returns an object constructed from `config`, the type of the object
    is specified by `registry`.

    The ``'type'`` entry of `config` names the module to look up in `registry`;
    all remaining entries are passed to that module as keyword arguments. The
    ``'type'`` entry is removed from a copy, so `config` itself keeps it.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :param registry: A registry specifying the type of the return object
    :type registry: :class:`Registry`
    :raises AssertionError: Raises an AssertionError if `registry` is not an object
        of :class:`Registry` or `mod_type` in `config` is not found in `registry`
    :raises KeyError: Raises a KeyError if `config` has no ``'type'`` entry
    :raises Exception: Raises an Exception if an error occurred when building
        from registry
    :return: An object specified by `registry`
    :rtype: Python object specified by `registry`
    """
    config_ = config.copy()  # keep the original config untouched

    # Explicit raises (rather than ``assert`` statements) keep the documented
    # checks active when Python runs with ``-O``.
    if not isinstance(registry, Registry):
        raise AssertionError(
            f'Expected type Registry but got {type(registry)}')

    try:
        mod_type = config_.pop('type')
    except KeyError:
        # Same exception type as before, but with a message that says what
        # is missing instead of a bare "KeyError: 'type'".
        raise KeyError(
            f"config must contain a 'type' key naming a module in registry {registry.name}"
        ) from None

    if not registry.has(mod_type):
        raise AssertionError(
            f'{mod_type} is not found in registry {registry.name}')

    try:
        obj = registry.get_module(mod_type)(**config_)
    except Exception:
        print(
            f'An error occurred when building {mod_type} from registry {registry.name}',
            flush=True)
        # Bare ``raise`` re-raises the active exception unchanged.
        raise

    return obj
|
|
|
|
|
def build_layer(config):
    """Builds a layer object from `config` using the ``LAYERS`` registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the layer to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.nn.Module`
    :rtype: :class:`torch.nn.Module`
    """
    layer = build_from_registry(config, LAYERS)
    return layer
|
|
|
|
|
def build_loss(config):
    """Builds a loss function object from `config` using the ``LOSSES`` registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the loss to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.nn.modules.loss._Loss`
    :rtype: :class:`torch.nn.modules.loss._Loss`
    """
    loss = build_from_registry(config, LOSSES)
    return loss
|
|
|
|
|
def build_model(config):
    """Builds a model object from `config` using the ``MODELS`` registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the model to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.nn.Module`
    :rtype: :class:`torch.nn.Module`
    """
    model = build_from_registry(config, MODELS)
    return model
|
|
|
|
|
def build_dataset(config):
    """Builds a dataset object from `config` using the ``DATASETS`` registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the dataset to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torch.utils.data.Dataset`
    :rtype: :class:`torch.utils.data.Dataset`
    """
    dataset = build_from_registry(config, DATASETS)
    return dataset
|
|
|
|
|
def build_optimizer(config, model):
    """Builds an optimizer object from `config` using the ``OPTIMIZERS`` registry.

    A copy of `config` is extended with a ``'params'`` entry holding
    ``model.parameters()`` before the optimizer is constructed, so the caller's
    `config` is not modified.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the optimizer to construct
    :type config: dict or :class:`colossalai.context.Config`
    :param model: A model containing parameters for the optimizer
    :type model: :class:`nn.Module`
    :return: An object of :class:`torch.optim.Optimizer`
    :rtype: :class:`torch.optim.Optimizer`
    """
    optimizer_config = config.copy()
    optimizer_config['params'] = model.parameters()
    optimizer = build_from_registry(optimizer_config, OPTIMIZERS)
    return optimizer
|
|
|
|
|
def build_gradient_handler(config, model, optimizer):
    """Builds a gradient handler object from `config` using the
    ``GRADIENT_HANDLER`` registry.

    A copy of `config` is extended with ``'model'`` and ``'optimizer'`` entries
    before the handler is constructed, so the caller's `config` is not modified.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the gradient handler to construct
    :type config: dict or :class:`colossalai.context.Config`
    :param model: A model containing parameters for the gradient handler
    :type model: :class:`nn.Module`
    :param optimizer: An optimizer object containing parameters for the gradient handler
    :type optimizer: :class:`torch.optim.Optimizer`
    :return: An object of :class:`colossalai.engine.BaseGradientHandler`
    :rtype: :class:`colossalai.engine.BaseGradientHandler`
    """
    handler_config = config.copy()
    # Insertion order is kept as model first, then optimizer.
    handler_config['model'] = model
    handler_config['optimizer'] = optimizer
    handler = build_from_registry(handler_config, GRADIENT_HANDLER)
    return handler
|
|
|
|
|
def build_hooks(config, trainer):
    """Builds a hook object from `config` using the ``HOOKS`` registry.

    A copy of `config` is extended with a ``'trainer'`` entry before the hook
    is constructed, so the caller's `config` is not modified.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the hook to construct
    :type config: dict or :class:`colossalai.context.Config`
    :param trainer: A :class:`Trainer` object containing parameters for the hook
    :type trainer: :class:`Trainer`
    :return: An object of :class:`colossalai.trainer.hooks.BaseHook`
    :rtype: :class:`colossalai.trainer.hooks.BaseHook`
    """
    hook_config = config.copy()
    hook_config['trainer'] = trainer
    hook = build_from_registry(hook_config, HOOKS)
    return hook
|
|
|
|
|
def build_ophooks(config):
    """Returns a hook object of :class:`BaseOpHook` constructed from `config`.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        containing information used in the construction of the return object
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`colossalai.trainer.hooks.BaseOpHook`
    :rtype: :class:`colossalai.trainer.hooks.BaseOpHook`
    """
    # No local copy is needed: nothing is added to the config here, and
    # ``build_from_registry`` already works on its own copy of `config`.
    return build_from_registry(config, OPHOOKS)
|
|
|
|
|
def build_transform(config):
    """Builds a transformation object from `config` using the ``TRANSFORMS``
    registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the transform to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`torchvision.transforms`
    :rtype: :class:`torchvision.transforms`
    """
    transform = build_from_registry(config, TRANSFORMS)
    return transform
|
|
|
|
|
def build_data_sampler(config, dataset):
    """Builds a data sampler object from `config` using the ``DATA_SAMPLERS``
    registry.

    A copy of `config` is extended with a ``'dataset'`` entry before the
    sampler is constructed, so the caller's `config` is not modified.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the sampler to construct
    :type config: dict or :class:`colossalai.context.Config`
    :param dataset: An object of :class:`torch.utils.data.Dataset` containing information
        used in the construction of the return object
    :type dataset: :class:`torch.utils.data.Dataset`
    :return: The sampler built from the registry (expected to be a ``BaseSampler``;
        its exact module path is not visible from this file)
    :rtype: ``BaseSampler``
    """
    sampler_config = config.copy()
    sampler_config['dataset'] = dataset
    sampler = build_from_registry(sampler_config, DATA_SAMPLERS)
    return sampler
|
|
|
|
|
def build_lr_scheduler(config, optimizer):
    """Builds a learning rate scheduler object from `config` using the
    ``LR_SCHEDULERS`` registry.

    A copy of `config` is extended with an ``'optimizer'`` entry before the
    scheduler is constructed; any other constructor argument has to be present
    in `config` already. The caller's `config` is not modified.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the scheduler to construct
    :type config: dict or :class:`colossalai.context.Config`
    :param optimizer: An optimizer object containing parameters for the learning rate
        scheduler
    :type optimizer: :class:`torch.optim.Optimizer`
    :return: An object of :class:`torch.optim.lr_scheduler`
    :rtype: :class:`torch.optim.lr_scheduler`
    """
    scheduler_config = config.copy()
    scheduler_config['optimizer'] = optimizer
    scheduler = build_from_registry(scheduler_config, LR_SCHEDULERS)
    return scheduler
|
|
|
|
|
def build_schedule(config):
    """Builds a schedule object from `config` using the ``SCHEDULE`` registry.

    :param config: A python dict or a :class:`colossalai.context.Config` object
        describing the schedule to construct
    :type config: dict or :class:`colossalai.context.Config`
    :return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
    :rtype: :class:`colossalai.engine.schedule.BaseSchedule`
    """
    schedule = build_from_registry(config, SCHEDULE)
    return schedule
|
|
|