Browse Source

[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)

pull/2829/head
Michelle 2 years ago committed by GitHub
parent
commit
c008d4ad0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      colossalai/engine/schedule/_pipeline_schedule.py

4
colossalai/engine/schedule/_pipeline_schedule.py

@ -4,8 +4,9 @@
import inspect import inspect
from typing import Callable, List, Tuple, Union from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda import torch.cuda
import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine): def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism # TODO: remove this after testing new zero with pipeline parallelism
model = engine.model model = engine.model
if isinstance(model, NaiveAMPModel): if isinstance(model, NaiveAMPModel):

Loading…
Cancel
Save