You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/amp/apex_amp/__init__.py

43 lines
1.6 KiB

import torch.nn as nn
from torch.optim import Optimizer
from .apex_amp import ApexAMPOptimizer
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    r"""Wrap a model and its optimizer with Apex AMP components.

    The model and optimizer are handed to ``apex.amp.initialize`` with the
    entries of ``amp_config`` expanded as keyword arguments; the optimizer
    that comes back is then wrapped in :class:`ApexAMPOptimizer`.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): settings
            forwarded to ``apex.amp.initialize``. It is unpacked with ``**``,
            so it must be a mapping.

    Returns:
        Tuple: A tuple ``(model, optimizer)`` holding the model returned by
        ``apex.amp.initialize`` and the :class:`ApexAMPOptimizer` wrapper
        around the optimizer it returned.

    The ``amp_config`` should include parameters below:
    ::

        enabled (bool, optional, default=True)
        opt_level (str, optional, default="O1")
        cast_model_type (``torch.dtype``, optional, default=None)
        patch_torch_functions (bool, optional, default=None)
        keep_batchnorm_fp32 (bool or str, optional, default=None)
        master_weights (bool, optional, default=None)
        loss_scale (float or str, optional, default=None)
        cast_model_outputs (torch.dtype, optional, default=None)
        num_losses (int, optional, default=1)
        verbosity (int, default=1)
        min_loss_scale (float, default=None)
        max_loss_scale (float, default=2.**24)

    For more details about ``amp_config``, refer to
    `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
    """
    # apex is imported at call time rather than when this module is imported.
    import apex.amp as amp_backend

    amp_model, amp_optimizer = amp_backend.initialize(model, optimizer, **amp_config)
    return amp_model, ApexAMPOptimizer(amp_optimizer)
# Names exported by ``from <this package> import *``.
__all__ = [
    'convert_to_apex_amp',
    'ApexAMPOptimizer',
]