|
|
@ -4,8 +4,9 @@ |
|
|
|
import inspect |
|
|
|
import inspect |
|
|
|
from typing import Callable, List, Tuple, Union |
|
|
|
from typing import Callable, List, Tuple, Union |
|
|
|
|
|
|
|
|
|
|
|
import colossalai.communication as comm |
|
|
|
|
|
|
|
import torch.cuda |
|
|
|
import torch.cuda |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import colossalai.communication as comm |
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel |
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel |
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
from colossalai.core import global_context as gpc |
|
|
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule): |
|
|
|
|
|
|
|
|
|
|
|
def pre_processing(self, engine): |
|
|
|
def pre_processing(self, engine): |
|
|
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 |
|
|
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 |
|
|
|
|
|
|
|
|
|
|
|
# TODO: remove this after testing new zero with pipeline parallelism |
|
|
|
# TODO: remove this after testing new zero with pipeline parallelism |
|
|
|
model = engine.model |
|
|
|
model = engine.model |
|
|
|
if isinstance(model, NaiveAMPModel): |
|
|
|
if isinstance(model, NaiveAMPModel): |
|
|
|