[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)

pull/2829/head
Michelle 2 years ago committed by GitHub
parent 58abde2857
commit c008d4ad0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,8 +4,9 @@
import inspect
from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda
import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):

Loading…
Cancel
Save