mirror of https://github.com/hpcaitech/ColossalAI
Fix False warning in initialize.py (#2456)
* Update initialize.py * pre-commit run checkpull/2465/head
parent
32c46e146e
commit
9358262992
|
@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
|
||||
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
|
||||
from colossalai.amp import AMP_TYPE, convert_to_amp
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.builder.builder import build_gradient_handler
|
||||
from colossalai.context import Config, ConfigException, ParallelMode
|
||||
from colossalai.context.moe_context import MOE_CONTEXT
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.engine.gradient_accumulation import accumulate_gradient
|
||||
|
||||
from colossalai.engine.schedule import (
|
||||
InterleavedPipelineSchedule,
|
||||
NonPipelineSchedule,
|
||||
PipelineSchedule,
|
||||
get_tensor_shape,
|
||||
)
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
||||
|
||||
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
|
||||
from colossalai.utils.moe import sync_moe_model_param
|
||||
from colossalai.zero import convert_to_zero_v2
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
|
@ -301,9 +300,9 @@ def initialize(model: nn.Module,
|
|||
model = model().to(get_current_device())
|
||||
|
||||
# optimizer maybe a optimizer_cls
|
||||
logger.warning("Initializing an non ZeRO model with optimizer class")
|
||||
if isinstance(optimizer, Callable):
|
||||
optimizer = optimizer(model.parameters())
|
||||
logger.warning("Initializing an non ZeRO model with optimizer class")
|
||||
|
||||
if not use_zero:
|
||||
if is_using_sequence():
|
||||
|
|
Loading…
Reference in New Issue