# Source: ColossalAI/colossalai/booster/mixed_precision/mixed_precision_base.py
# (Python; listed by the code viewer as 24 lines, 565 B)

from abc import ABC, abstractmethod
from typing import Callable, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
class MixedPrecision(ABC):
    """Abstract interface for mixed precision training.

    ``configure`` is declared with ``@abstractmethod``, so this class cannot
    be instantiated directly; a concrete subclass has to override it.
    """

    @abstractmethod
    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        """Configure the given training components for mixed precision.

        Args:
            model: The module to configure.
            optimizer: Optional optimizer; defaults to ``None``.
            criterion: Optional callable; defaults to ``None``
                (presumably the loss function -- confirm against subclasses).

        Returns:
            Annotated as a 3-tuple of ``nn.Module``, ``OptimizerWrapper`` and
            ``Callable``. The base implementation contains no logic and
            returns ``None``; what the tuple elements hold is up to the
            overriding subclass.
        """
        # TODO: implement this method
        return None