[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)

pull/2829/head
Michelle 2023-02-20 10:38:40 +08:00 committed by GitHub
parent 58abde2857
commit c008d4ad0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 6 deletions

View File

@ -4,8 +4,9 @@
import inspect
from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda
import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):