mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.6 KiB
43 lines
1.6 KiB
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
|
|
from .apex_amp import ApexAMPOptimizer
|
|
|
|
|
|
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config=None):
    r"""A helper function to wrap training components with Apex AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (Union[:class:`colossalai.context.Config`, dict], optional): configuration for
            initializing apex_amp. Its entries are forwarded as keyword arguments to
            ``apex.amp.initialize``. ``None`` (the default) is treated as an empty config,
            so no keyword arguments are passed and apex's own defaults apply.

    Returns:
        Tuple: A tuple (model, optimizer), where ``optimizer`` is the optimizer returned by
        ``apex.amp.initialize`` wrapped in :class:`ApexAMPOptimizer`.

    The ``amp_config`` should include parameters below:
    ::

        enabled (bool, optional, default=True)
        opt_level (str, optional, default="O1")
        cast_model_type (``torch.dtype``, optional, default=None)
        patch_torch_functions (bool, optional, default=None)
        keep_batchnorm_fp32 (bool or str, optional, default=None)
        master_weights (bool, optional, default=None)
        loss_scale (float or str, optional, default=None)
        cast_model_outputs (torch.dtype, optional, default=None)
        num_losses (int, optional, default=1)
        verbosity (int, default=1)
        min_loss_scale (float, default=None)
        max_loss_scale (float, default=2.**24)

    For more details about ``amp_config``, refer to
    `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
    """
    # Imported inside the function, so apex is only loaded when this converter is
    # actually called (presumably to keep apex an optional dependency).
    import apex.amp as apex_amp

    # ``**None`` raises TypeError; a missing config means "use apex's defaults".
    if amp_config is None:
        amp_config = {}

    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    # Wrap the apex-patched optimizer so callers get the project's optimizer interface.
    optimizer = ApexAMPOptimizer(optimizer)
    return model, optimizer
|
|
|
|
|
|
# Public API of this package: the converter helper and the optimizer wrapper it returns.
__all__ = [
    "convert_to_apex_amp",
    "ApexAMPOptimizer",
]
|