import torch.nn as nn
from torch.optim import Optimizer

from .apex_amp import ApexAMPOptimizer


def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    r"""A helper function to wrap training components with Apex AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp.
            Its entries are forwarded as keyword arguments to ``apex.amp.initialize``.

    Returns:
        Tuple: A tuple (model, optimizer), where ``model`` is the model returned by
        ``apex.amp.initialize`` and ``optimizer`` is the Apex-initialized optimizer
        wrapped in :class:`ApexAMPOptimizer`.

    Raises:
        ImportError: If NVIDIA Apex is not installed.

    The ``amp_config`` may include any of the parameters below:
    ::

        enabled (bool, optional, default=True)
        opt_level (str, optional, default="O1")
        cast_model_type (``torch.dtype``, optional, default=None)
        patch_torch_functions (bool, optional, default=None)
        keep_batchnorm_fp32 (bool or str, optional, default=None)
        master_weights (bool, optional, default=None)
        loss_scale (float or str, optional, default=None)
        cast_model_outputs (torch.dtype, optional, default=None)
        num_losses (int, optional, default=1)
        verbosity (int, optional, default=1)
        min_loss_scale (float, optional, default=None)
        max_loss_scale (float, optional, default=2.**24)

    For more details about ``amp_config``, refer to
    `amp_config <https://nvidia.github.io/apex/amp.html>`_.
    """
    # Apex is imported lazily so this module can be imported without Apex
    # installed; the dependency is only needed when Apex AMP is requested.
    try:
        import apex.amp as apex_amp
    except ImportError as err:
        raise ImportError(
            "convert_to_apex_amp requires NVIDIA Apex "
            "(https://github.com/NVIDIA/apex); please install it to use Apex AMP."
        ) from err

    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    optimizer = ApexAMPOptimizer(optimizer)
    return model, optimizer


__all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"]