You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/amp/naive_amp/__init__.py

39 lines
1.3 KiB

import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
def convert_to_naive_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
"""A helper function to wrap training components with Torch AMP modules
:param model: your model object
:type model: :class:`torch.nn.Module`
:param optimizer: your optimizer object
:type optimizer: :class:`torch.optim.Optimzer`
:param amp_config: configuration for naive mode amp
:type amp_config: :class:`colossalai.context.Config` or dict
:return: (model, optimizer)
:rtype: Tuple
"""
if isinstance(model, nn.ModuleList):
# interleaved pipeline
module_list = []
for chunk, m in enumerate(model):
output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
model = nn.ModuleList(module_list)
else:
output_to_fp32 = is_no_pp_or_last_stage()
model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']