mirror of https://github.com/hpcaitech/ColossalAI
[refactor] pipeline, put runtime schedule into engine. (#627)
parent e5d615aeee
commit ade05a5d83
@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 from asyncio.log import logger
-from typing import List
+from typing import List, Iterable
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.logging import get_dist_logger
@@ -27,6 +28,7 @@ class Engine:
         clip_grad_norm (float, optional): The norm of gradient clipping.
         ophook_list (list): List of ophook.
         verbose (bool): whether to display log info.
+        schedule (``BaseSchedule``): Runtime schedule.
 
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -59,7 +61,8 @@ class Engine:
                  gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
                  ophook_list: Optional[List[BaseOpHook]] = None,
-                 verbose: bool = True):
+                 verbose: bool = True,
+                 schedule: Optional[BaseSchedule] = None):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
@@ -80,6 +83,14 @@ class Engine:
             self._ophook_list = []
         else:
             self._ophook_list = ophook_list
+
+        # build schedule
+        if schedule:
+            self._schedule = schedule
+        else:
+            self._schedule = NonPipelineSchedule()
+        if self.uses_pipeline:
+            self._schedule.pre_processing(self)
         register_ophooks_recursively(self._model, self._ophook_list)
 
     @property
@@ -102,6 +113,16 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion
 
+    @property
+    def schedule(self):
+        """Schedule attached to the engine"""
+        return self._schedule
+
+    @property
+    def uses_pipeline(self):
+        """show the pipeline parallel used or not"""
+        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
+
     def add_hook(self, ophook: Type[BaseOpHook]) -> None:
         """add necessary hook"""
         # whether this hook exist
@@ -165,6 +186,16 @@ class Engine:
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
+    def execute_schedule(self, data_iter: Iterable, **kwargs):
+        """Run the forward, loss computation, and backward for the model.
+        Returns a tuple of (output, label, loss).
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+        """
+        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
+        return output, label, loss
+
     def train(self):
         """Sets the model to training mode.
@@ -176,4 +207,4 @@ class Engine:
         """Sets the model to evaluation mode.
         """
         self.training = False
         self._model.eval()

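After this change the engine owns its runtime schedule, so a hand-written loop only needs the engine and a data iterator. A minimal sketch of the new call pattern, assuming model, optimizer, criterion and train_dataloader are already defined as in the docstring example; the loop itself is illustrative and not part of this commit:

import colossalai

# colossalai.initialize() builds the engine and now also attaches a schedule to it.
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

engine.train()
data_iter = iter(train_dataloader)

engine.zero_grad()
# forward, loss computation and backward are delegated to the attached schedule
output, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
engine.step()

# engine.schedule exposes the schedule object; engine.uses_pipeline is True only
# when a PipelineSchedule or InterleavedPipelineSchedule was attached.
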
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
 import torch
 
 from typing import Iterable, Callable
-from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
             return self._move_to_device(data), self._move_to_device(label)
         return data, label
 
-    def pre_processing(self, engine: Engine):
+    def pre_processing(self, engine):
         """To perform actions before running the schedule.
         """
         pass
 
     @abstractmethod
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool,
                               return_loss: bool = True,

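The Engine import and its type hints are dropped here so that the schedule package no longer imports the engine; the dependency now points the other way, with the engine importing the schedules. A schedule subclass still receives the engine object at runtime. Below is a minimal sketch of that contract; the class name and method bodies are illustrative only and not part of this commit:

from typing import Iterable

from colossalai.engine.schedule import BaseSchedule


class SimpleSchedule(BaseSchedule):
    """Illustrative non-pipelined schedule: consumes one batch per step."""

    def pre_processing(self, engine):
        # engine is deliberately untyped here, mirroring the circular-import fix above
        pass

    def forward_backward_step(self, engine, data_iter: Iterable,
                              forward_only: bool, return_loss: bool = True, **kwargs):
        data, label = next(data_iter)
        data, label = self._move_to_device(data), self._move_to_device(label)
        output = engine(data)
        loss = engine.criterion(output, label) if return_loss else None
        if not forward_only:
            engine.backward(loss)
        return output, label, loss

Such a schedule could then be handed to the engine through the new schedule argument, or, more commonly, be built inside colossalai.initialize as shown further below.
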
@@ -5,7 +5,6 @@ from typing import Iterable
 
 import torch
 
-from colossalai.engine import Engine
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
 
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
     """
 
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool = False,
                               return_loss: bool = True,

@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
     if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
         model.module.sync_buffer = False
 
+    # initialize schedule for engine
+    if is_using_pp():
+        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if use_interleaved:
+            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+        else:
+            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+    else:
+        schedule = NonPipelineSchedule()
+
     if gradient_handler_cfg is None:
         gradient_handlers = None
         if verbose and not isinstance(model, DDP):
@@ -418,6 +433,7 @@ def initialize(model: nn.Module,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
                     clip_grad_norm=clip_grad_norm,
-                    ophook_list=ophooks)
+                    ophook_list=ophooks,
+                    schedule=schedule)
 
     return engine, train_dataloader, test_dataloader, lr_scheduler

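With the selection logic above, the schedule is derived entirely from the global config instead of being constructed by the caller. A sketch of a config module that exercises those branches; the numeric values and the commented-out keys are illustrative:

# A pipeline parallel size > 1 makes initialize() build a PipelineSchedule from
# NUM_MICRO_BATCHES; without pipeline parallelism it falls back to NonPipelineSchedule.
NUM_MICRO_BATCHES = 4
parallel = dict(pipeline=2)

# Optional keys read by the same code path:
# TENSOR_SHAPE = (1, 128, 768)       # forwarded as tensor_shape= to the pipeline schedules
# model = dict(..., num_chunks=2)    # a num_chunks entry selects InterleavedPipelineSchedule
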
@@ -9,7 +9,6 @@ from tqdm import tqdm
 from colossalai.core import global_context as gpc
 
 from colossalai.engine import Engine
-from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:
 
     Args:
         engine (:class:`Engine`): Engine responsible for the process function.
-        schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
         timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
         logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
 
-    Note:
-        when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
-        you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
-
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
         >>> # Beginning training progress
         >>> timier = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
        >>> # add hooks you would like to use here.
        >>> hook_list = []
        >>> trainer.fit(
@@ -61,7 +56,6 @@ class Trainer:
     def __init__(
         self,
         engine: Engine,
-        schedule: BaseSchedule = None,
         timer: MultiTimer = None,
         logger: DistributedLogger = None,
     ):
@@ -86,17 +80,6 @@ class Trainer:
         # multi-timer for time benchmarking
         self._timer = timer
 
-        # set schedule which specifies the training iteration for the engine
-        if schedule is None:
-            schedule = NonPipelineSchedule()
-        if (gpc.is_initialized(ParallelMode.PIPELINE)
-                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
-            assert not isinstance(
-                schedule, NonPipelineSchedule
-            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
-        self._schedule = schedule
-        self._schedule.pre_processing(engine)
-
     @property
     def cur_epoch(self):
         """Returns the index of the current epoch."""
@@ -129,10 +112,6 @@ class Trainer:
     def engine(self):
         return self._engine
 
-    @property
-    def schedule(self):
-        return self._schedule
-
     def _set_current_step(self, epoch: int):
         """Sets current step number.
 
@@ -203,8 +182,7 @@ class Trainer:
 
             # run 1 training step
             self.engine.zero_grad()
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=False,
                 return_loss=True,
@@ -260,8 +238,7 @@ class Trainer:
             for _ in progress:
                 self._call_hooks("before_test_iter")
                 self._call_timer(action="start", item="Test-step")
-                logits, label, loss = self.schedule.forward_backward_step(
-                    self.engine,
+                logits, label, loss = self.engine.execute_schedule(
                     data_iter,
                     forward_only=True,
                     return_loss=True,
@@ -449,8 +426,7 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(self.engine,
-                                                           data_iter,
-                                                           forward_only=True,
-                                                           return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter,
+                                                    forward_only=True,
+                                                    return_loss=False)
         return output

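For downstream code that used to build a schedule and pass it to the Trainer, or that called forward_backward_step directly, the migration is mechanical. A before/after sketch using the variable names from the docstring example above:

from colossalai.trainer import Trainer

# Before: the caller created the schedule and fed the engine into every step.
# schedule = PipelineSchedule(num_microbatches=2)
# trainer = Trainer(engine=engine, schedule=schedule, logger=logger, timer=timer)
# logits, label, loss = schedule.forward_backward_step(engine, data_iter,
#                                                      forward_only=False, return_loss=True)

# After: the schedule is built inside colossalai.initialize() and lives on the engine.
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logits, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
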
@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.NAIVE),
              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                                                          train_dataloader,
                                                                          lr_scheduler=lr_scheduler)
 
-    schedule = PipelineSchedule(num_microbatches=2)
     logger = get_dist_logger()
 
-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)
 
     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),

@@ -7,6 +7,7 @@ IMG_SIZE = 224
 DIM = 768
 NUM_CLASSES = 10
 NUM_ATTN_HEADS = 12
+NUM_MICRO_BATCHES = 2
 
 # resnet 18
 model = dict(type='VanillaResNet',

@@ -19,7 +19,6 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
 BATCH_SIZE = 4
-NUM_MICRO = 2
 
 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
     engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
 
     # build pipeline schedule
-    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
+    schedule = engine.schedule
 
     # run schedule
     data_iter = iter(train_dataloader)

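As the hunk above shows, tests can still reach the schedule through engine.schedule and drive it directly. A short sketch of that pattern; the surrounding names come from the test, and the optimizer step via engine.step() is the usual engine API rather than something introduced by this commit:

engine.train()
schedule = engine.schedule            # PipelineSchedule built from NUM_MICRO_BATCHES in the config
data_iter = iter(train_dataloader)

engine.zero_grad()
output, label, loss = schedule.forward_backward_step(engine, data_iter,
                                                     forward_only=False, return_loss=True)
engine.step()
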
@@ -23,7 +23,7 @@ BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
 
-CONFIG = dict(parallel=dict(pipeline=2),)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
 
 
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):
 
     logger = get_dist_logger()
     logger.info("engine is built", ranks=[0])
-    pipe_schedule = PipelineSchedule(num_microbatches=2)
     timer = MultiTimer()
-    trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("trainer is built", ranks=[0])
 
     logger.info("start training", ranks=[0])