diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index bb2b8eb26..5b2f71d3c 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -1,10 +1,13 @@ import inspect + import torch.nn as nn from torch.optim import Optimizer + from colossalai.utils import is_no_pp_or_last_stage -from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel -from .grad_scaler import DynamicGradScaler, ConstantGradScaler + from ._fp16_optimizer import FP16Optimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):