diff --git a/colossalai/initialize.py b/colossalai/initialize.py index e907efdde..f3719dcb4 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from colossalai.core import global_context as gpc -from colossalai.context.moe_context import MOE_CONTEXT - -from colossalai.logging import get_dist_logger - -from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape -from colossalai.engine import Engine -from colossalai.gemini.ophooks import BaseOpHook - -from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param) -from colossalai.utils.moe import sync_moe_model_param - from colossalai.amp import AMP_TYPE, convert_to_amp from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.builder.builder import build_gradient_handler from colossalai.context import Config, ConfigException, ParallelMode +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.engine import Engine from colossalai.engine.gradient_accumulation import accumulate_gradient - +from colossalai.engine.schedule import ( + InterleavedPipelineSchedule, + NonPipelineSchedule, + PipelineSchedule, + get_tensor_shape, +) +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer - +from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param +from colossalai.utils.moe import sync_moe_model_param from colossalai.zero import convert_to_zero_v2 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 @@ -301,9 +300,9 @@ def initialize(model: nn.Module, model = model().to(get_current_device()) # optimizer maybe a optimizer_cls - logger.warning("Initializing an non ZeRO model with optimizer class") if isinstance(optimizer, Callable): optimizer = optimizer(model.parameters()) + logger.warning("Initializing an non ZeRO model with optimizer class") if not use_zero: if is_using_sequence():