|
|
|
@ -15,26 +15,25 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
|
from torch.optim.optimizer import Optimizer |
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.context.moe_context import MOE_CONTEXT |
|
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
|
|
|
|
|
|
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape |
|
|
|
|
from colossalai.engine import Engine |
|
|
|
|
from colossalai.gemini.ophooks import BaseOpHook |
|
|
|
|
|
|
|
|
|
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param) |
|
|
|
|
from colossalai.utils.moe import sync_moe_model_param |
|
|
|
|
|
|
|
|
|
from colossalai.amp import AMP_TYPE, convert_to_amp |
|
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel |
|
|
|
|
from colossalai.builder.builder import build_gradient_handler |
|
|
|
|
from colossalai.context import Config, ConfigException, ParallelMode |
|
|
|
|
from colossalai.context.moe_context import MOE_CONTEXT |
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.engine import Engine |
|
|
|
|
from colossalai.engine.gradient_accumulation import accumulate_gradient |
|
|
|
|
|
|
|
|
|
from colossalai.engine.schedule import ( |
|
|
|
|
InterleavedPipelineSchedule, |
|
|
|
|
NonPipelineSchedule, |
|
|
|
|
PipelineSchedule, |
|
|
|
|
get_tensor_shape, |
|
|
|
|
) |
|
|
|
|
from colossalai.gemini.ophooks import BaseOpHook |
|
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
|
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer |
|
|
|
|
|
|
|
|
|
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param |
|
|
|
|
from colossalai.utils.moe import sync_moe_model_param |
|
|
|
|
from colossalai.zero import convert_to_zero_v2 |
|
|
|
|
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 |
|
|
|
|
|
|
|
|
@ -301,9 +300,9 @@ def initialize(model: nn.Module,
|
|
|
|
|
model = model().to(get_current_device()) |
|
|
|
|
|
|
|
|
|
# optimizer maybe a optimizer_cls |
|
|
|
|
logger.warning("Initializing an non ZeRO model with optimizer class") |
|
|
|
|
if isinstance(optimizer, Callable): |
|
|
|
|
optimizer = optimizer(model.parameters()) |
|
|
|
|
logger.warning("Initializing an non ZeRO model with optimizer class") |
|
|
|
|
|
|
|
|
|
if not use_zero: |
|
|
|
|
if is_using_sequence(): |
|
|
|
|