[pipeline] refactor 1f1b schedule (#4115)

* [api] update optimizer wrapper to fit pipeline

* [pipeline] add base schedule

* [pipeline] add 1f1b schedule

* [test] add pipeline schedule utils test

* [pipeline] fix import
pull/4445/head
Hongxin Liu 2023-06-29 13:35:39 +08:00
parent 45fdc9b42c
commit f51ce1bc8e
6 changed files with 451 additions and 0 deletions


@@ -1,5 +1,6 @@
from typing import Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
@@ -53,6 +54,9 @@ class OptimizerWrapper:
        """
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        torch.autograd.backward(tensor, grad)

    def state_dict(self):
        """
        Returns the optimizer state.
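The new `backward_by_grad` hook exists so a pipeline stage can start backpropagation from an intermediate output using a gradient received from the next stage, rather than from a scalar loss. A minimal sketch of the call semantics, using illustrative stand-in tensors:

import torch

# Stand-in for a stage's output and the gradient received from the next stage.
x = torch.randn(4, 8, requires_grad=True)
y = (x * 2).sum(dim=1)                      # "stage output", shape (4,)
grad_from_next_stage = torch.ones_like(y)

# Equivalent to OptimizerWrapper.backward_by_grad(y, grad_from_next_stage):
# backprop starts at y with an externally supplied gradient.
torch.autograd.backward(y, grad_from_next_stage)
print(x.grad.shape)                          # torch.Size([4, 8]); upstream grads populated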


@@ -0,0 +1,7 @@
from .base import PipelineSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule

__all__ = [
    'PipelineSchedule',
    'OneForwardOneBackwardSchedule',
]


@@ -0,0 +1,129 @@
from typing import Any, List, Optional

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Move object to device if it is a tensor.

    Args:
        x (Any): Object to be moved.
        device (Optional[torch.device], optional): Target device. Defaults to None.

    Returns:
        Any: Moved object.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return x


def get_batch_size(batch: Any) -> int:
    """Get the batch size (size of dimension 0) of the first tensor in the batch.

    Args:
        batch (Any): Batch to be inspected.

    Raises:
        RuntimeError: If no tensor is found in the batch.

    Returns:
        int: Batch size.
    """
    data_list, _ = tree_flatten(batch)
    for data in data_list:
        if isinstance(data, torch.Tensor):
            return data.size(0)
    raise RuntimeError('No tensor found in the batch')


def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Get a micro batch of the original batch.

    Args:
        batch (Any): Batch to be sliced.
        start (int): Start index of the micro batch.
        micro_batch_size (int): Size of the micro batch.

    Returns:
        Any: Target micro batch.
    """

    def _get_tensor_slice(x: Any):
        if isinstance(x, torch.Tensor):
            return x[start:start + micro_batch_size]
        return x

    return tree_map(_get_tensor_slice, batch)
def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Call the model's forward function with data and internal inputs.

    Args:
        model (Module): Model to be called.
        data (Any): Data loaded from the data iterator.
        internal_inputs (Optional[dict]): Data from the previous stage. It must be a dict, or None on the first stage.

    Returns:
        Any: Outputs of the model.
    """
    if internal_inputs is None:
        internal_inputs = {}
    if isinstance(data, (list, tuple)):
        return model(*data, **internal_inputs)
    elif isinstance(data, dict):
        return model(**data, **internal_inputs)
    return model(data, **internal_inputs)
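To make the calling convention concrete, here is a small hedged sketch (the `MidStage` module is hypothetical): the micro batch supplies the positional arguments, and the previous stage's outputs are splatted in as keyword arguments.

import torch
import torch.nn as nn

from colossalai.pipeline.schedule._utils import model_forward


class MidStage(nn.Module):
    # Hypothetical stage: accepts the raw micro batch plus hidden states
    # forwarded from the previous stage as a keyword argument.
    def forward(self, x, hidden_states=None):
        h = x if hidden_states is None else hidden_states
        return {'hidden_states': h + 1}


stage = MidStage()
micro_batch = torch.zeros(2, 4)

out = model_forward(stage, micro_batch, None)    # first stage: no internal inputs
out = model_forward(stage, micro_batch, {'hidden_states': out['hidden_states']})    # later stage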
def retain_grad(x: Any) -> None:
    """Call retain_grad() on the object if it is a tensor.

    Args:
        x (Any): Object to be processed.
    """
    if isinstance(x, torch.Tensor):
        x.retain_grad()


def detach(x: Any) -> Any:
    """Call detach() on the object if it is a tensor.

    Args:
        x (Any): Object to be processed.

    Returns:
        Any: The detached object.
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    return x
def merge_batch(data: List[Any]) -> Any:
    """Merge micro batches into a batch.

    Args:
        data (List[Any]): A list of micro batches.

    Returns:
        Any: Merged batch.
    """
    if len(data) == 0:
        return
    flattened_data = []
    tree_spec = None
    for d in data:
        elems, tree_spec = tree_flatten(d)
        flattened_data.append(elems)
    merged_data = []
    for elem_batch in zip(*flattened_data):
        if isinstance(elem_batch[0], torch.Tensor):
            merged_data.append(torch.cat(elem_batch, dim=0))
        else:
            merged_data.append(list(elem_batch))
    return tree_unflatten(merged_data, tree_spec)
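`merge_batch` is what the schedule later uses to stitch per-micro-batch outputs back into one batch-shaped structure; a quick sketch with dict-shaped outputs (the data here is made up):

import torch

from colossalai.pipeline.schedule._utils import merge_batch

# Pretend these are the last stage's outputs for two micro batches of size 2.
micro_outputs = [
    {'logits': torch.zeros(2, 5)},
    {'logits': torch.ones(2, 5)},
]
merged = merge_batch(micro_outputs)
print(merged['logits'].shape)    # torch.Size([4, 5])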


@@ -0,0 +1,35 @@
from typing import Any, Callable, Iterable

from torch import Tensor
from torch.nn import Module

from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.stage_manager import PipelineStageManager


class PipelineSchedule:

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager

    def forward_backward_step(self,
                              model: Module,
                              optimizer: OptimizerWrapper,
                              data_iter: Iterable,
                              criterion: Callable[[Any, Any], Tensor],
                              return_loss: bool = False,
                              return_outputs: bool = False) -> dict:
        """Forward and backward step for pipeline training.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: A dict with keys 'loss' and 'outputs'.
        """
        raise NotImplementedError


@@ -0,0 +1,229 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule


class OneForwardOneBackwardSchedule(PipelineSchedule):

    def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
        super().__init__(stage_manager)
        self.comm = PipelineP2PCommunication(stage_manager)
        self.num_microbatches = num_microbatches
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.microbatch_offset: Optional[int] = None
        self.microbatch_size: Optional[int] = None
    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from the data iterator.

        Args:
            data_iter (Iterable): Data iterator.
            device (Optional[torch.device], optional): Target device. Defaults to None.
        """
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)
        self.batch = batch
        self.batch_size = get_batch_size(batch)
        self.microbatch_offset = 0
        assert self.batch_size % self.num_microbatches == 0, \
            "Batch size should be divisible by the number of microbatches"
        self.microbatch_size = self.batch_size // self.num_microbatches
    def load_micro_batch(self) -> Any:
        """Load a micro batch from the current batch.

        Returns:
            Any: Micro batch.
        """
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
        self.microbatch_offset += self.microbatch_size
        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
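As a concrete illustration of the slicing these two methods perform: with a global batch of 8 samples and `num_microbatches=4`, `load_batch` computes `microbatch_size = 2` and each `load_micro_batch` call advances the offset by 2. A self-contained sketch using the `_utils` helpers above:

import torch

from colossalai.pipeline.schedule._utils import get_batch_size, get_micro_batch

batch = torch.arange(8).view(8, 1).float()
num_microbatches = 4
batch_size = get_batch_size(batch)                   # 8
assert batch_size % num_microbatches == 0
microbatch_size = batch_size // num_microbatches     # 2

for offset in range(0, batch_size, microbatch_size):
    micro = get_micro_batch(batch, offset, microbatch_size)
    print(offset, micro.shape)                       # (2, 1) slices at offsets 0, 2, 4, 6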
    def forward_step(self,
                     model: Module,
                     input_obj: Optional[dict],
                     criterion: Callable,
                     accum_loss: Optional[torch.Tensor] = None,
                     outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
        """Forward one step of the pipeline.

        Args:
            model (Module): Model to be run.
            input_obj (Optional[dict]): Output from the previous stage. On the first stage, `input_obj` is None.
            criterion (Callable): Criterion to calculate the loss.
            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.

        Returns:
            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. On the last stage, the output is the loss (Tensor).
        """
        micro_batch = self.load_micro_batch()

        # On the first stage, input_obj is None; on every other stage, it is the output
        # of the previous stage and must be a dict.
        output_obj = model_forward(model, micro_batch, input_obj)
        if self.stage_manager.is_last_stage():
            loss = criterion(output_obj, micro_batch) / self.num_microbatches
            if accum_loss is not None:
                accum_loss.add_(loss.detach())
            if outputs is not None:
                outputs.append(tree_map(detach, output_obj))
            return loss
        else:
            return output_obj
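Dividing each micro-batch loss by `num_microbatches` before accumulation keeps `accum_loss` equal to the full-batch loss when the criterion uses mean reduction and the micro batches are equally sized; a quick check with made-up data:

import torch

criterion = torch.nn.CrossEntropyLoss()    # mean reduction by default
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

full_loss = criterion(logits, labels)
num_microbatches = 4
accum_loss = sum(
    criterion(logits[i:i + 2], labels[i:i + 2]) / num_microbatches
    for i in range(0, 8, 2))
torch.testing.assert_close(accum_loss, full_loss)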
    def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
                      output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
        """Backward one step of the pipeline.

        Args:
            optimizer (OptimizerWrapper): Optimizer used to update the model.
            input_obj (Optional[dict]): Output of the previous stage. On the first stage, `input_obj` is None.
            output_obj (Union[dict, torch.Tensor]): Output of the current stage. On the last stage, the output is the loss (Tensor).
            output_obj_grad (Optional[dict]): Gradient of the `output_obj`. On the last stage, `output_obj_grad` is None.

        Returns:
            Optional[dict]: Gradient of the `input_obj`, to be sent to the previous stage. On the first stage, it is None.
        """
        # Retain the grad on the input_obj.
        tree_map(retain_grad, input_obj)

        # Backward pass: from the loss on the last stage, otherwise from the gradients
        # received from the next stage.
        if output_obj_grad is None:
            optimizer.backward(output_obj)
        else:
            for k, grad in output_obj_grad.items():
                optimizer.backward_by_grad(output_obj[k], grad)

        # Collect the grad of the input_obj.
        input_obj_grad = None
        if input_obj is not None:
            input_obj_grad = {}
            for k, v in input_obj.items():
                if isinstance(v, torch.Tensor) and v.grad is not None:
                    input_obj_grad[k] = v.grad
        return input_obj_grad
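The gradient handoff here rests on two pieces: `retain_grad()` so the tensors received from the previous stage keep a `.grad` even though they are non-leaf, and `backward_by_grad` to seed backpropagation with the gradient dict received from the next stage. A stripped-down sketch of that mechanism in plain torch, with stand-in tensors and no communication:

import torch

# "input_obj": activation received from the previous stage (non-leaf after further ops).
prev_out = torch.randn(2, 4, requires_grad=True)
input_obj = {'hidden_states': prev_out * 1.0}
input_obj['hidden_states'].retain_grad()            # keep .grad on a non-leaf tensor

# "output_obj": this stage's output; "output_obj_grad": gradient from the next stage.
output_obj = {'hidden_states': input_obj['hidden_states'] * 3.0}
output_obj_grad = {'hidden_states': torch.ones(2, 4)}

for k, grad in output_obj_grad.items():              # mirrors optimizer.backward_by_grad
    torch.autograd.backward(output_obj[k], grad)

# Gradient to send back to the previous stage.
input_obj_grad = {'hidden_states': input_obj['hidden_states'].grad}
print(input_obj_grad['hidden_states'])               # a tensor of threes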
    def forward_backward_step(self,
                              model: Module,
                              optimizer: OptimizerWrapper,
                              data_iter: Iterable,
                              criterion: Callable[..., Any],
                              return_loss: bool = False,
                              return_outputs: bool = False) -> dict:
        """Runs the non-interleaved 1F1B schedule, with communication between pipeline stages.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: A dict with keys 'loss' and 'outputs'.
        """
        forward_only = not torch.is_grad_enabled()

        self.load_batch(data_iter)

        # Number of warmup microbatches: forward passes run while the pipeline is still
        # filling up and not all stages are busy yet.
        num_warmup_microbatches = self.stage_manager.num_stages - self.stage_manager.stage - 1
        num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input and output tensors only need to be saved when doing backward passes.
        input_objs = None
        output_objs = None
        if not forward_only:
            input_objs = []
            output_objs = []
        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

        if return_loss and self.stage_manager.is_last_stage():
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            input_obj = self.comm.recv_forward()
            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
            self.comm.send_forward(output_obj)

            if not forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)

        # Before running 1F1B, need to receive the first forward tensor.
        # If all microbatches are run in the warmup / cooldown phases, there is no need
        # to receive this tensor here.
        if num_microbatches_remaining > 0:
            input_obj = self.comm.recv_forward()

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
            if forward_only:
                self.comm.send_forward(output_obj)

                if not last_iteration:
                    input_obj = self.comm.recv_forward()
            else:
                # TODO adjust here
                self.comm.send_forward(output_obj)
                output_obj_grad = self.comm.recv_backward()

                # Add input_obj and output_obj to the end of the list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

                if last_iteration:
                    input_obj = None
                else:
                    input_obj = self.comm.recv_forward()
                self.comm.send_backward(input_obj_grad)

        # Run cooldown backward passes.
        if not forward_only:
            for i in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                output_obj_grad = self.comm.recv_backward()
                input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                self.comm.send_backward(input_obj_grad)

        if outputs is not None:
            outputs = merge_batch(outputs)
        return {'loss': accum_loss, 'outputs': outputs}
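As a worked example of the schedule shape: with 4 stages and 8 microbatches, stage 0 runs 3 warmup forwards, 5 steady 1F1B iterations, and 3 cooldown backwards, while the last stage runs no warmup or cooldown at all. The per-stage counts follow directly from the formulas above:

num_stages = 4
num_microbatches = 8

for stage in range(num_stages):
    num_warmup = min(num_stages - stage - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    # Cooldown backward passes mirror the warmup forward passes.
    print(f"stage {stage}: warmup={num_warmup}, steady 1F1B={num_steady}, cooldown={num_warmup}")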


@@ -0,0 +1,47 @@
import torch

from colossalai.pipeline.schedule._utils import get_batch_size, get_micro_batch, merge_batch


def test_get_batch_size():
    tensor = torch.rand(2, 3)
    assert get_batch_size(tensor) == 2
    assert get_batch_size([tensor]) == 2
    assert get_batch_size((1, tensor)) == 2
    assert get_batch_size({'tensor': tensor}) == 2
    assert get_batch_size({'dummy': [1], 'tensor': tensor}) == 2
    assert get_batch_size({'tensor': [tensor]}) == 2


def test_get_micro_batch():
    x = torch.rand(2, 1)
    y = torch.rand(2, 3)
    micro_batch = get_micro_batch(x, 0, 1)
    assert torch.equal(micro_batch, x[0:1])
    micro_batch = get_micro_batch(x, 1, 1)
    assert torch.equal(micro_batch, x[1:2])
    micro_batch = get_micro_batch([x, y], 0, 1)
    assert torch.equal(micro_batch[0], x[0:1])
    assert torch.equal(micro_batch[1], y[0:1])
    micro_batch = get_micro_batch([x, y], 1, 1)
    assert torch.equal(micro_batch[0], x[1:2])
    assert torch.equal(micro_batch[1], y[1:2])
    micro_batch = get_micro_batch({'x': x, 'y': y}, 0, 1)
    assert torch.equal(micro_batch['x'], x[0:1])
    assert torch.equal(micro_batch['y'], y[0:1])
    micro_batch = get_micro_batch({'x': x, 'y': y}, 1, 1)
    assert torch.equal(micro_batch['x'], x[1:2])
    assert torch.equal(micro_batch['y'], y[1:2])


def test_merge_batch():
    x = torch.rand(2, 1)
    y = torch.rand(2, 3)
    merged = merge_batch([x[0:1], x[1:2]])
    assert torch.equal(merged, x)
    merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]])
    assert torch.equal(merged[0], x)
    assert torch.equal(merged[1], y)
    merged = merge_batch([{'x': x[0:1], 'y': y[0:1]}, {'x': x[1:2], 'y': y[1:2]}])
    assert torch.equal(merged['x'], x)
    assert torch.equal(merged['y'], y)