diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index ff8979d82..db27ad0e8 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -12,7 +12,7 @@ from colossalai.engine.gradient_handler import BaseGradientHandler from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule from colossalai.logging import get_dist_logger from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively - +from colossalai.nn.optimizer import ColossalaiOptimizer class Engine: """Basic engine class for training and evaluation. It runs a specific process method