[refactor] pipeline, put runtime schedule into engine. (#627)

YuliangLiu0306 2022-04-03 20:46:45 +08:00 committed by GitHub
parent e5d615aeee
commit ade05a5d83
9 changed files with 68 additions and 49 deletions
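
The net effect for callers: the runtime schedule now lives on the Engine. colossalai.initialize builds a NonPipelineSchedule, PipelineSchedule or InterleavedPipelineSchedule from the config and attaches it to the engine, and iteration code goes through engine.execute_schedule instead of owning a schedule object. A rough before/after sketch in the style of the docstring examples below, assuming model, optimizer, criterion, train_dataloader and a launched config are already set up:

>>> # before this commit: user code (or the Trainer) constructed the schedule itself
>>> schedule = PipelineSchedule(num_microbatches=2)
>>> trainer = Trainer(engine=engine, schedule=schedule, logger=logger)
>>>
>>> # after this commit: initialize() builds the schedule from the config and hands it to the engine
>>> engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion, train_dataloader)
>>> trainer = Trainer(engine=engine, logger=logger)
>>> output, label, loss = engine.execute_schedule(iter(train_dataloader), forward_only=False, return_loss=True)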

View File

@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 from asyncio.log import logger
-from typing import List
+from typing import List, Iterable
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.logging import get_dist_logger
@@ -27,6 +28,7 @@ class Engine:
         clip_grad_norm (float, optional): The norm of gradient clipping.
         ophook_list (list): List of ophook.
         verbose (bool): whether to display log info.
+        schedule (``BaseSchedule``, optional): Runtime schedule.
 
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -59,7 +61,8 @@ class Engine:
                  gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
                  ophook_list: Optional[List[BaseOpHook]] = None,
-                 verbose: bool = True):
+                 verbose: bool = True,
+                 schedule: Optional[BaseSchedule] = None):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
@@ -80,6 +83,14 @@ class Engine:
             self._ophook_list = []
         else:
             self._ophook_list = ophook_list
 
+        # build schedule
+        if schedule:
+            self._schedule = schedule
+        else:
+            self._schedule = NonPipelineSchedule()
+        if self.uses_pipeline:
+            self._schedule.pre_processing(self)
+
         register_ophooks_recursively(self._model, self._ophook_list)
 
     @property
@@ -102,6 +113,16 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion
 
+    @property
+    def schedule(self):
+        """Schedule attached to the engine"""
+        return self._schedule
+
+    @property
+    def uses_pipeline(self):
+        """Whether a pipeline-parallel schedule is in use"""
+        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
+
     def add_hook(self, ophook: Type[BaseOpHook]) -> None:
         """add necessary hook"""
         # whether this hook exist
@@ -165,6 +186,16 @@ class Engine:
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
+    def execute_schedule(self, data_iter: Iterable, **kwargs):
+        """Run the forward, loss computation, and backward for the model.
+        Returns a tuple of (output, label, loss).
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+        """
+        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
+        return output, label, loss
+
     def train(self):
         """Sets the model to training mode.
@@ -176,4 +207,4 @@ class Engine:
         """Sets the model to evaluation mode.
         """
         self.training = False
         self._model.eval()
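
Taken together, the Engine gains three schedule-related entry points: the schedule property, the uses_pipeline check, and execute_schedule, which is a thin wrapper around self._schedule.forward_backward_step. A hedged sketch of one training iteration through this new surface (train_dataloader is assumed to exist; engine.step() is the pre-existing optimizer step and is not part of this diff):

>>> data_iter = iter(train_dataloader)
>>> engine.train()
>>> engine.zero_grad()
>>> logits, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
>>> engine.step()
>>> # engine.schedule returns the attached BaseSchedule instance;
>>> # engine.uses_pipeline is True only for PipelineSchedule / InterleavedPipelineSchedule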

View File

@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
 import torch
 
 from typing import Iterable, Callable
-from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
             return self._move_to_device(data), self._move_to_device(label)
         return data, label
 
-    def pre_processing(self, engine: Engine):
+    def pre_processing(self, engine):
         """To perform actions before running the schedule.
         """
         pass
 
     @abstractmethod
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool,
                               return_loss: bool = True,

View File

@@ -5,7 +5,6 @@ from typing import Iterable
 import torch
 
-from colossalai.engine import Engine
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
     """
 
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool = False,
                               return_loss: bool = True,

View File

@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
     if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
         model.module.sync_buffer = False
 
+    # initialize schedule for engine
+    if is_using_pp():
+        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if use_interleaved:
+            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, gpc.config.model.num_chunks,
+                                                   tensor_shape=tensor_shape, scatter_gather_tensors=True)
+        else:
+            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+    else:
+        schedule = NonPipelineSchedule()
+
     if gradient_handler_cfg is None:
         gradient_handlers = None
         if verbose and not isinstance(model, DDP):
@@ -418,6 +433,7 @@ def initialize(model: nn.Module,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
                     clip_grad_norm=clip_grad_norm,
-                    ophook_list=ophooks)
+                    ophook_list=ophooks,
+                    schedule=schedule)
 
     return engine, train_dataloader, test_dataloader, lr_scheduler
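
The schedule is therefore chosen purely from the config: when pipeline parallelism is active, NUM_MICRO_BATCHES is read, a model.num_chunks entry switches to the interleaved schedule, and TENSOR_SHAPE is picked up when present; otherwise a NonPipelineSchedule is built. Illustrative config fragments for the three branches (values are examples only, and a real config would also have to describe the model itself):

>>> # pipeline parallel -> PipelineSchedule(NUM_MICRO_BATCHES, ...)
>>> CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
>>> # pipeline parallel with a chunked model -> InterleavedPipelineSchedule(NUM_MICRO_BATCHES, num_chunks, ...)
>>> CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2), model=dict(num_chunks=2))
>>> # no pipeline parallelism -> NonPipelineSchedule()
>>> CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))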

View File

@@ -9,7 +9,6 @@ from tqdm import tqdm
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:
     Args:
         engine (:class:`Engine`): Engine responsible for the process function.
-        schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
         timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
         logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
 
-    Note:
-        when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
-        you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
 
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
         >>> # Beginning training progress
         >>> timier = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
         >>> # add hooks you would like to use here.
         >>> hook_list = []
         >>> trainer.fit(
@@ -61,7 +56,6 @@ class Trainer:
     def __init__(
         self,
         engine: Engine,
-        schedule: BaseSchedule = None,
         timer: MultiTimer = None,
         logger: DistributedLogger = None,
     ):
@@ -86,17 +80,6 @@ class Trainer:
         # multi-timer for time benchmarking
         self._timer = timer
 
-        # set schedule which specifies the training iteration for the engine
-        if schedule is None:
-            schedule = NonPipelineSchedule()
-        if (gpc.is_initialized(ParallelMode.PIPELINE)
-                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
-            assert not isinstance(
-                schedule, NonPipelineSchedule
-            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
-        self._schedule = schedule
-        self._schedule.pre_processing(engine)
-
     @property
     def cur_epoch(self):
         """Returns the index of the current epoch."""
@@ -129,10 +112,6 @@ class Trainer:
     def engine(self):
         return self._engine
 
-    @property
-    def schedule(self):
-        return self._schedule
-
     def _set_current_step(self, epoch: int):
         """Sets current step number.
@@ -203,8 +182,7 @@ class Trainer:
             # run 1 training step
             self.engine.zero_grad()
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=False,
                 return_loss=True,
@@ -260,8 +238,7 @@ class Trainer:
         for _ in progress:
             self._call_hooks("before_test_iter")
             self._call_timer(action="start", item="Test-step")
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=True,
                 return_loss=True,
@@ -449,8 +426,7 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(self.engine,
-                                                           data_iter,
-                                                           forward_only=True,
-                                                           return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter,
+                                                    forward_only=True,
+                                                    return_loss=False)
         return output
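
With the schedule argument gone, the Trainer's train, eval and predict paths above differ only in the flags they pass to engine.execute_schedule. The three call patterns, as they read after this change (engine, data_iter and data are assumed to exist):

>>> # training step: forward + backward, loss returned
>>> logits, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
>>> # evaluation step: forward only, loss still returned for the metric hooks
>>> logits, label, loss = engine.execute_schedule(data_iter, forward_only=True, return_loss=True)
>>> # prediction: forward only, no loss; the label slot is packed with None for schedule compatibility
>>> output, _, _ = engine.execute_schedule(iter([(data, None)]), forward_only=True, return_loss=False)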

View File

@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.NAIVE),
              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                        train_dataloader,
                                        lr_scheduler=lr_scheduler)
-    schedule = PipelineSchedule(num_microbatches=2)
 
     logger = get_dist_logger()
 
-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)
 
     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),

View File

@@ -7,6 +7,7 @@ IMG_SIZE = 224
 DIM = 768
 NUM_CLASSES = 10
 NUM_ATTN_HEADS = 12
+NUM_MICRO_BATCHES = 2
 
 # resnet 18
 model = dict(type='VanillaResNet',

View File

@@ -19,7 +19,6 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
 BATCH_SIZE = 4
-NUM_MICRO = 2
 
 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
     engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
 
     # build pipeline schedule
-    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
+    schedule = engine.schedule
 
     # run schedule
     data_iter = iter(train_dataloader)
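
Since the engine now exposes its schedule, code that drives the pipeline by hand (as this test does) can fetch it and call forward_backward_step directly, using the BaseSchedule signature shown earlier. A sketch of how such a loop presumably continues, assuming the surrounding test setup:

>>> schedule = engine.schedule
>>> data_iter = iter(train_dataloader)
>>> engine.train()
>>> engine.zero_grad()
>>> output, label, loss = schedule.forward_backward_step(engine, data_iter, forward_only=False, return_loss=True)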

View File

@@ -23,7 +23,7 @@ BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
 
-CONFIG = dict(parallel=dict(pipeline=2),)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
 
 
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):
     logger = get_dist_logger()
     logger.info("engine is built", ranks=[0])
 
-    pipe_schedule = PipelineSchedule(num_microbatches=2)
     timer = MultiTimer()
-    trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("trainer is built", ranks=[0])
 
     logger.info("start training", ranks=[0])