mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] refactor pipeline (#679)
* refactor pipeline: put the runtime schedule into the engine
* add type hint for schedule: Optional[BaseSchedule]
* preprocess the schedule during engine initialization
* infer pipeline schedule params from the config
parent  eace69387d
commit  0ed7042f42
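The commit moves the runtime schedule from the trainer into the engine. Before the diff hunks, here is a minimal sketch of what that ownership looks like; it is not the actual ColossalAI Engine, and the constructor signature and the pre_processing hook are assumptions used only to illustrate "preprocess the schedule during engine initialization".

# Minimal sketch only: the real Engine in colossalai/engine differs.
from typing import Optional

from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule


class SketchEngine:
    def __init__(self, model, optimizer, criterion,
                 schedule: Optional[BaseSchedule] = None):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        # the schedule now lives on the engine instead of the trainer
        self._schedule = schedule if schedule is not None else NonPipelineSchedule()
        # hypothetical hook: preprocess the schedule once while the engine initializes
        if hasattr(self._schedule, 'pre_processing'):
            self._schedule.pre_processing(self)

    @property
    def schedule(self) -> BaseSchedule:
        return self._schedule

With this layout, hooks reach the batch size through trainer.engine.schedule rather than trainer.schedule, which is exactly what the hook changes at the end of this diff do.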
@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
 
-__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
+__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
@@ -16,6 +16,29 @@ from colossalai.zero.sharded_model import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
+def get_tensor_shape():
+    if hasattr(gpc.config, 'TENSOR_SHAPE'):
+        return gpc.config.TENSOR_SHAPE
+
+    if not gpc.is_initialized(ParallelMode.PIPELINE):
+        return None
+
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'NUM_MICRO_BATCHES') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+        if gpc.is_initialized(ParallelMode.DATA):
+            dp_size = gpc.get_world_size(ParallelMode.DATA)
+        else:
+            dp_size = 1
+        if gpc.is_initialized(ParallelMode.SEQUENCE):
+            seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
+        else:
+            seq_size = 1
+
+        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
+                        gpc.config.HIDDEN_SIZE)
+        return tensor_shape
+    else:
+        return None
+
 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
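The inference in get_tensor_shape() only applies when TENSOR_SHAPE is not set explicitly and the pipeline group is initialized. A small worked example with made-up config values (not repository defaults) shows the arithmetic:

# Hypothetical values, purely to illustrate the shape computation above.
SEQ_LENGTH = 512
GLOBAL_BATCH_SIZE = 256
HIDDEN_SIZE = 1024
NUM_MICRO_BATCHES = 4
dp_size, seq_size = 2, 1   # world sizes of the DATA and SEQUENCE parallel groups

tensor_shape = (SEQ_LENGTH // seq_size,                              # 512
                GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,   # 32
                HIDDEN_SIZE)                                         # 1024
print(tensor_shape)  # (512, 32, 1024)

That is the sequence length per sequence-parallel rank, the micro-batch size per data-parallel rank, and the hidden size: the shape of the activation tensor exchanged between pipeline stages.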
@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
 
     # initialize schedule for engine
     if is_using_pp():
-        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        tensor_shape = get_tensor_shape()
         use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
+            scatter_gather = True
+        else:
+            scatter_gather = False
         if use_interleaved:
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
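With this change, initialize() builds the schedule from the config instead of having the trainer pass one around. A hedged config sketch follows; the field names mirror the checks in the hunk above, while the concrete values are illustrative and not repository defaults:

# Illustrative config snippet for a pipeline run; values are made up.
parallel = dict(pipeline=2, tensor=dict(size=1, mode=None))
NUM_MICRO_BATCHES = 4

# Optional: set TENSOR_SHAPE to bypass get_tensor_shape()'s inference entirely.
# TENSOR_SHAPE = (512, 32, 1024)

# A 'num_chunks' field on the model config makes initialize() pick
# InterleavedPipelineSchedule; leaving it out selects PipelineSchedule.
model = dict(num_chunks=2)

When the tensor-parallel mode is '1d', the PARALLEL_1D group is initialized and scatter_gather_tensors is enabled; otherwise it now defaults to False instead of the previous hard-coded True.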
@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):
 
     def after_test_iter(self, trainer, logits, targets, *args):
         if self._is_stage_to_compute:
-            batch_size = trainer.schedule.batch_size
+            batch_size = trainer.engine.schedule.batch_size
             self.metric.update(logits, targets, batch_size)
@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):
 
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
 
     def before_test(self, trainer):
         if self._is_stage_to_compute:
@@ -400,4 +400,4 @@ class ThroughputHook(MetricHook):
 
     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
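Because the schedule now hangs off the engine, custom hooks must read the batch size through trainer.engine.schedule, mirroring the AccuracyHook and ThroughputHook changes above. A small sketch of such a hook; the class name is invented here, and the BaseHook import path and constructor signature are assumptions based on the hooks shown in this diff:

# Illustrative only: SamplesCounterHook is not part of the repository.
from colossalai.trainer.hooks import BaseHook


class SamplesCounterHook(BaseHook):
    def __init__(self, priority: int = 10):
        super().__init__(priority)
        self.total_samples = 0

    def after_train_iter(self, trainer, *args):
        # trainer.schedule is gone; the schedule is reached through the engine
        self.total_samples += trainer.engine.schedule.batch_size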