mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)
parent
58abde2857
commit
c008d4ad0c
|
@ -4,8 +4,9 @@
|
|||
import inspect
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import colossalai.communication as comm
|
||||
import torch.cuda
|
||||
|
||||
import colossalai.communication as comm
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
|
||||
def pre_processing(self, engine):
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
|
||||
# TODO: remove this after testing new zero with pipeline parallelism
|
||||
model = engine.model
|
||||
if isinstance(model, NaiveAMPModel):
|
||||
|
|
Loading…
Reference in New Issue