mirror of https://github.com/hpcaitech/ColossalAI
[refactor] pipeline, put runtime schedule into engine. (#627)
parent e5d615aeee
commit ade05a5d83

@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-

 from asyncio.log import logger
-from typing import List
+from typing import List, Iterable
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.logging import get_dist_logger
@@ -27,6 +28,7 @@ class Engine:
         clip_grad_norm (float, optional): The norm of gradient clipping.
         ophook_list (list): List of ophook.
         verbose (bool): whether to display log info.
+        schedule (``BaseSchedule``): Runtime schedule.

     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -59,7 +61,8 @@ class Engine:
                  gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
                  ophook_list: Optional[List[BaseOpHook]] = None,
-                 verbose: bool = True):
+                 verbose: bool = True,
+                 schedule: Optional[BaseSchedule] = None):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
@@ -80,6 +83,14 @@ class Engine:
             self._ophook_list = []
         else:
             self._ophook_list = ophook_list
+
+        # build schedule
+        if schedule:
+            self._schedule = schedule
+        else:
+            self._schedule = NonPipelineSchedule()
+        if self.uses_pipeline:
+            self._schedule.pre_processing(self)
         register_ophooks_recursively(self._model, self._ophook_list)

     @property
@@ -102,6 +113,16 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion

+    @property
+    def schedule(self):
+        """Schedule attached to the engine"""
+        return self._schedule
+
+    @property
+    def uses_pipeline(self):
+        """show the pipeline parallel used or not"""
+        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
+
     def add_hook(self, ophook: Type[BaseOpHook]) -> None:
         """add necessary hook"""
         # whether this hook exist
@@ -166,6 +187,16 @@ class Engine:
         for handler in self._gradient_handlers:
             handler.handle_gradient()

+    def execute_schedule(self, data_iter: Iterable, **kwargs):
+        """Run the forward, loss computation, and backward for the model.
+        Returns a tuple of (output, label, loss).
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+        """
+        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
+        return output, label, loss
+
     def train(self):
         """Sets the model to training mode.
         """
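
With the schedule now owned by the engine, a hand-written training loop no longer touches a schedule object at all: it calls engine.execute_schedule and the engine dispatches to whichever schedule colossalai.initialize attached. A minimal sketch of that pattern follows; the tiny model, data and step count are placeholders, and it assumes the distributed context has already been launched (e.g. with colossalai.launch_from_torch) and a config loaded.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import colossalai

# placeholder components -- any real model/optimizer/criterion/dataloader works the same way
model = nn.Linear(32, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 4, (64,)))
train_dataloader = DataLoader(dataset, batch_size=8)

# initialize() now attaches the runtime schedule to the engine it returns
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

engine.train()
data_iter = iter(train_dataloader)
for _ in range(4):                      # a handful of steps, purely for illustration
    engine.zero_grad()
    # forward, loss and backward are delegated to the schedule held by the engine
    output, label, loss = engine.execute_schedule(data_iter,
                                                  forward_only=False,
                                                  return_loss=True)
    engine.step()                       # optimizer update, assumed as in the trainer's training step
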
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
 import torch

 from typing import Iterable, Callable
-from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
             return self._move_to_device(data), self._move_to_device(label)
         return data, label

-    def pre_processing(self, engine: Engine):
+    def pre_processing(self, engine):
         """To perform actions before running the schedule.
         """
         pass

     @abstractmethod
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool,
                               return_loss: bool = True,
@@ -5,7 +5,6 @@ from typing import Iterable

 import torch

-from colossalai.engine import Engine
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context

@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
     """

     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool = False,
                               return_loss: bool = True,
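
The engine type hints are dropped from the schedules above, presumably because the engine module now imports the schedule classes and a typed reference the other way would create a circular import. A custom schedule still only has to implement forward_backward_step (and optionally pre_processing). The sketch below is illustrative only and not part of the patch; it assumes load_batch is the batch-loading helper on BaseSchedule and that the engine exposes __call__, criterion and backward as elsewhere in the codebase.

from typing import Iterable

from colossalai.engine.schedule import BaseSchedule


class SimpleSchedule(BaseSchedule):
    """A toy, non-pipelined schedule -- a hypothetical example, not a library class."""

    def forward_backward_step(self, engine, data_iter: Iterable,
                              forward_only: bool = False, return_loss: bool = True, **kwargs):
        # load_batch is assumed to be the (data, label) loading helper provided by BaseSchedule
        data, label = self.load_batch(data_iter)
        output = engine(data)                      # Engine.__call__ forwards to the wrapped model
        loss = engine.criterion(output, label) if return_loss else None
        if not forward_only:
            engine.backward(loss)                  # engine is assumed to handle AMP / gradient hooks
        return output, label, loss
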
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule

 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
     if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
         model.module.sync_buffer = False

+    # initialize schedule for engine
+    if is_using_pp():
+        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if use_interleaved:
+            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+        else:
+            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+    else:
+        schedule = NonPipelineSchedule()
+
+
     if gradient_handler_cfg is None:
         gradient_handlers = None
         if verbose and not isinstance(model, DDP):
@@ -418,6 +433,7 @@ def initialize(model: nn.Module,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
                     clip_grad_norm=clip_grad_norm,
-                    ophook_list=ophooks)
+                    ophook_list=ophooks,
+                    schedule=schedule)

     return engine, train_dataloader, test_dataloader, lr_scheduler
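
With this change, initialize() builds the runtime schedule from the global config: is_using_pp() routes pipeline runs to PipelineSchedule, gpc.config.NUM_MICRO_BATCHES sets the number of micro-batches, a num_chunks field on the model config selects InterleavedPipelineSchedule, and an optional TENSOR_SHAPE is forwarded to the pipeline schedule. A sketch of a config file that would drive those branches; all values and the model type are illustrative, not taken from the patch.

# config.py -- illustrative only
NUM_MICRO_BATCHES = 4            # required whenever pipeline parallelism is on
TENSOR_SHAPE = (4, 128, 768)     # optional; passed to the pipeline schedule if defined

parallel = dict(
    pipeline=2,                  # more than one pipeline stage -> is_using_pp() is True
)

model = dict(
    type='MyPipelineModel',      # hypothetical model type
    num_chunks=2,                # presence of num_chunks selects InterleavedPipelineSchedule
)
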
@@ -9,7 +9,6 @@ from tqdm import tqdm
 from colossalai.core import global_context as gpc

 from colossalai.engine import Engine
-from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:

     Args:
         engine (:class:`Engine`): Engine responsible for the process function.
-        schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
         timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
         logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.

-    Note:
-        when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
-        you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`

     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
         >>> # Beginning training progress
         >>> timier = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
         >>> # add hooks you would like to use here.
         >>> hook_list = []
         >>> trainer.fit(
@@ -61,7 +56,6 @@ class Trainer:
     def __init__(
         self,
         engine: Engine,
-        schedule: BaseSchedule = None,
         timer: MultiTimer = None,
         logger: DistributedLogger = None,
     ):
@@ -86,17 +80,6 @@ class Trainer:
         # multi-timer for time benchmarking
         self._timer = timer

-        # set schedule which specifies the training iteration for the engine
-        if schedule is None:
-            schedule = NonPipelineSchedule()
-        if (gpc.is_initialized(ParallelMode.PIPELINE)
-                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
-            assert not isinstance(
-                schedule, NonPipelineSchedule
-            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
-        self._schedule = schedule
-        self._schedule.pre_processing(engine)
-
     @property
     def cur_epoch(self):
         """Returns the index of the current epoch."""
@@ -129,10 +112,6 @@ class Trainer:
     def engine(self):
         return self._engine

-    @property
-    def schedule(self):
-        return self._schedule
-
     def _set_current_step(self, epoch: int):
         """Sets current step number.

@@ -203,8 +182,7 @@ class Trainer:

             # run 1 training step
             self.engine.zero_grad()
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=False,
                 return_loss=True,
@@ -260,8 +238,7 @@ class Trainer:
             for _ in progress:
                 self._call_hooks("before_test_iter")
                 self._call_timer(action="start", item="Test-step")
-                logits, label, loss = self.schedule.forward_backward_step(
-                    self.engine,
+                logits, label, loss = self.engine.execute_schedule(
                     data_iter,
                     forward_only=True,
                     return_loss=True,
@@ -449,8 +426,7 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(self.engine,
-                                                           data_iter,
+        output, _, _ = self.engine.execute_schedule(data_iter,
                                                     forward_only=True,
                                                     return_loss=False)
         return output
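
On the Trainer side, the schedule argument disappears: the trainer simply drives engine.execute_schedule, so pipeline and non-pipeline runs are constructed identically. A hedged sketch of the updated usage follows; model, optimizer, criterion and the dataloaders are assumed to be built beforehand, the distributed context launched, and the hook names are assumed from the trainer's standard hooks rather than taken from this patch.

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer

# initialize() attaches the schedule to the engine based on the loaded config
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion,
                                                                     train_dataloader, test_dataloader)

# no schedule argument any more: the engine already carries the schedule picked by initialize()
trainer = Trainer(engine=engine, logger=get_dist_logger(), timer=MultiTimer())

trainer.fit(train_dataloader=train_dataloader,
            epochs=2,                      # illustrative value
            test_dataloader=test_dataloader,
            hooks=[hooks.LossHook()],      # assumed standard hook; add LR/metric hooks as needed
            display_progress=True)
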
@@ -23,7 +23,7 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.NAIVE),
              gradient_accumulation=2)

@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                                                    train_dataloader,
                                                                    lr_scheduler=lr_scheduler)

-    schedule = PipelineSchedule(num_microbatches=2)
     logger = get_dist_logger()

-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)

     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),

@@ -7,6 +7,7 @@ IMG_SIZE = 224
 DIM = 768
 NUM_CLASSES = 10
 NUM_ATTN_HEADS = 12
+NUM_MICRO_BATCHES = 2

 # resnet 18
 model = dict(type='VanillaResNet',

@@ -19,7 +19,6 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10

 BATCH_SIZE = 4
-NUM_MICRO = 2

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
     engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

     # build pipeline schedule
-    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
+    schedule = engine.schedule

     # run schedule
     data_iter = iter(train_dataloader)
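
As the test above now does, code that used to construct a PipelineSchedule by hand can simply read the schedule back from the engine and drive it directly. A short sketch under the same assumptions as the earlier examples (engine and train_dataloader returned by colossalai.initialize):

schedule = engine.schedule       # created and pre-processed by initialize()/Engine

engine.train()
data_iter = iter(train_dataloader)
output, label, loss = schedule.forward_backward_step(engine, data_iter,
                                                     forward_only=False,
                                                     return_loss=True)
# equivalent one-liner through the engine:
# output, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
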
@@ -23,7 +23,7 @@ BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200

-CONFIG = dict(parallel=dict(pipeline=2),)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)


 def run_trainer_with_pipeline(rank, world_size, port):
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):

     logger = get_dist_logger()
     logger.info("engine is built", ranks=[0])
-    pipe_schedule = PipelineSchedule(num_microbatches=2)
     timer = MultiTimer()
-    trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("trainer is built", ranks=[0])

     logger.info("start training", ranks=[0])