[pipeline] refactor pipeline (#679)

* refactor pipeline: move the runtime schedule into the engine

* add the type hint Optional[BaseSchedule] for the schedule argument

* preprocess the schedule during engine initialization

* infer pipeline schedule parameters from the config
YuliangLiu0306 2022-04-07 15:54:14 +08:00 committed by GitHub
parent eace69387d
commit 0ed7042f42
4 changed files with 36 additions and 9 deletions
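The commit bullets above amount to this: the schedule object now lives on the engine rather than on the trainer, and it is prepared while the engine is constructed. Below is a minimal sketch of that shape, assuming the engine accepts an Optional[BaseSchedule] and falls back to NonPipelineSchedule when none is given; EngineSketch and its attributes are illustrative stand-ins, not the actual Engine code.

from typing import Optional

from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule


class EngineSketch:
    # Illustrative stand-in for the refactored Engine: it owns the schedule.

    def __init__(self, schedule: Optional[BaseSchedule] = None):
        # "preprocess schedule during engine initializing": default to a
        # non-pipeline schedule, then let it prepare itself against the engine.
        # (Using pre_processing as the hook here is an assumption drawn from
        # the commit message, not a confirmed API detail.)
        self._schedule = schedule if schedule is not None else NonPipelineSchedule()
        self._schedule.pre_processing(self)

    @property
    def schedule(self) -> BaseSchedule:
        return self._schedule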


@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
-__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
+__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']


@@ -16,6 +16,29 @@ from colossalai.zero.sharded_model import ShardedModelV2

 from ._base_schedule import BaseSchedule

+
+def get_tensor_shape():
+    if hasattr(gpc.config, 'TENSOR_SHAPE'):
+        return gpc.config.TENSOR_SHAPE
+
+    if not gpc.is_initialized(ParallelMode.PIPELINE):
+        return None
+
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'NUM_MICRO_BATCHES') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+        if gpc.is_initialized(ParallelMode.DATA):
+            dp_size = gpc.get_world_size(ParallelMode.DATA)
+        else:
+            dp_size = 1
+        if gpc.is_initialized(ParallelMode.SEQUENCE):
+            seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
+        else:
+            seq_size = 1
+        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
+                        gpc.config.HIDDEN_SIZE)
+        return tensor_shape
+    else:
+        return None
+
+
 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
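To make the inference above concrete, here is the same arithmetic on a made-up config; the attribute names match the diff, while the sizes and parallel world sizes are purely illustrative.

SEQ_LENGTH = 512          # sequence length per sample
GLOBAL_BATCH_SIZE = 32    # batch size across all data-parallel ranks
NUM_MICRO_BATCHES = 4     # pipeline micro-batches per step
HIDDEN_SIZE = 1024        # model hidden dimension

dp_size = 2               # assumed data-parallel world size
seq_size = 1              # assumed sequence-parallel world size

# per-stage micro-batch size: 32 // 2 // 4 = 4
tensor_shape = (SEQ_LENGTH // seq_size,
                GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,
                HIDDEN_SIZE)
print(tensor_shape)       # (512, 4, 1024)

This is the shape of the activation tensor exchanged between pipeline stages; an explicit TENSOR_SHAPE in the config still takes precedence, since it is checked first.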


@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
     # initialize schedule for engine
     if is_using_pp():
-        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        tensor_shape = get_tensor_shape()
         use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
+            scatter_gather = True
+        else:
+            scatter_gather = False
         if use_interleaved:
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
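With get_tensor_shape() and the scatter_gather switch in place, the pipeline schedule can be built entirely from the configuration. A hedged example of such a config follows; the field names mirror the diff, but the values and the parallel layout are only illustrative.

# colossalai-style config module; TENSOR_SHAPE may be omitted because
# get_tensor_shape() can derive it from the four fields below.
NUM_MICRO_BATCHES = 4
SEQ_LENGTH = 512
GLOBAL_BATCH_SIZE = 32
HIDDEN_SIZE = 1024

parallel = dict(
    pipeline=2,
    tensor=dict(size=2, mode='1d'),   # 1D tensor parallelism enables scatter_gather
)

# Giving the model config a num_chunks entry, e.g. model = dict(num_chunks=2, ...),
# would make initialize() pick InterleavedPipelineSchedule instead.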


@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):

     def after_test_iter(self, trainer, logits, targets, *args):
         if self._is_stage_to_compute:
-            batch_size = trainer.schedule.batch_size
+            batch_size = trainer.engine.schedule.batch_size
             self.metric.update(logits, targets, batch_size)

@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):

     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
         if self._is_stage_to_compute:

@@ -400,4 +400,4 @@ class ThroughputHook(MetricHook):

     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
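Because the schedule has moved onto the engine, the same change applies to user-defined hooks: read the batch size through trainer.engine rather than trainer. A minimal sketch with a hypothetical hook (the class name and the print are placeholders, not part of this diff):

class BatchSizeLoggingHook:
    # Hypothetical hook body, shown only to illustrate the new attribute path;
    # the trainer calls this after every training iteration.
    def after_train_iter(self, trainer, *args):
        batch_size = trainer.engine.schedule.batch_size   # was: trainer.schedule.batch_size
        print(f'micro-batch size this step: {batch_size}')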