ColossalAI/colossalai/amp/naive_amp/naive_amp.py

66 lines
1.8 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper for naive mixed-precision (fp16) training.

    The given optimizer is wrapped in an ``FP16Optimizer``; loss scaling and
    the parameter update are delegated to that wrapper.

    Args:
        optim: The optimizer to wrap.
        *args: Extra positional arguments forwarded to ``FP16Optimizer``.
        **kwargs: Extra keyword arguments forwarded to ``FP16Optimizer``.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
        # Pass the optimizer positionally. Spelling it ``optimizer=optim``
        # next to ``*args`` raises ``TypeError: got multiple values for
        # argument 'optimizer'`` as soon as any extra positional argument is
        # supplied, because that argument binds to the first parameter.
        # NOTE(review): assumes ``optimizer`` is the first parameter of
        # ``FP16Optimizer.__init__`` -- confirm against _fp16_optimizer.py.
        optim = FP16Optimizer(optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
        """Scale ``loss`` with the wrapped optimizer, then backpropagate.

        Args:
            loss: The loss tensor to scale and call ``backward()`` on.
        """
        loss = self.optim.scale_loss(loss)
        loss.backward()

    def step(self):
        """Run one update step through the wrapped ``FP16Optimizer``."""
        self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Do nothing; both arguments are ignored.

        NOTE(review): presumably gradient clipping is performed inside the
        wrapped ``FP16Optimizer`` and this override only disables the base
        class behaviour -- confirm.
        """
        pass
class NaiveAMPModel(nn.Module):
    """Wrap a module so that it runs in fp16.

    ``model.half()`` is applied to the given module at construction time.
    On every forward call, top-level float32 tensor arguments are cast to
    float16 before being passed on, and (optionally) float16 outputs are
    cast back to float32.

    Args:
        model: The module to wrap; ``half()`` is called on it.
        output_to_fp32: If True, float16 tensors in the output are converted
            to float32 before being returned.
    """

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32

    def _convert_to_fp16(self, input_: Any):
        """Return ``input_`` as float16 if it is a float32 tensor.

        Anything else (non-tensors, tensors of other dtypes) is returned
        unchanged.
        """
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        """Return ``input_`` as float32 if it is a float16 tensor.

        Anything else (non-tensors, tensors of other dtypes) is returned
        unchanged.
        """
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_

    def forward(self, *args, **kwargs):
        """Run the wrapped model with fp16 inputs.

        Only top-level positional and keyword arguments are inspected;
        tensors nested inside containers are passed through as they are.

        Returns:
            The model output. When ``output_to_fp32`` is set, a tensor
            output is converted to float32; a tuple or list output is
            returned as a ``list`` with its float16 tensors converted; a
            plain ``dict`` output is returned as a new ``dict`` with its
            float16 tensor values converted. Any other output type is
            returned unchanged.
        """
        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            kwargs = {k: self._convert_to_fp16(v) for k, v in kwargs.items()}

        out = self.model(*args, **kwargs)

        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                # Kept as a list (even for tuple outputs) so existing callers
                # see the same container type as before.
                out = [self._convert_to_fp32(val) for val in out]
            elif type(out) is dict:
                # Exact-type check on purpose: rebuilding a dict subclass as
                # a plain dict would drop its extra behaviour, so subclasses
                # are still returned untouched.
                out = {k: self._convert_to_fp32(v) for k, v in out.items()}
        return out