mirror of https://github.com/hpcaitech/ColossalAI
645 lines
28 KiB
Python
645 lines
28 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from typing import Union
|
|
|
|
import torch.cuda
|
|
import torch.distributed as dist
|
|
from torch import Tensor
|
|
|
|
from colossalai.communication import *
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
|
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
|
ZeroRedundancyOptimizer_Level_3)
|
|
from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
|
|
from ._base_schedule import BaseSchedule
|
|
|
|
|
|
def squeeze(x: Union[Tensor, tuple, list]):
    """Unwrap the first element of a tuple/list, pass anything else through.

    Used by the schedules to turn a one-element ``(tensor,)`` container (as
    produced by ``load_micro_batch`` or returned by a model) into the tensor
    itself.

    :param x: a tensor, or a tuple/list whose first element is wanted
    :return: ``x[0]`` if ``x`` is a tuple or list, otherwise ``x`` unchanged
    :raises IndexError: if ``x`` is an empty tuple or list
    """
    if isinstance(x, (tuple, list)):
        return x[0]
    return x
|
|
|
|
|
|
class PipelineSchedule(BaseSchedule):
    """A helper schedule class for pipeline parallelism running environment.
    It uses non-interleaved 1F1B strategy. Other properties are similar as
    :class:`NonPipelineSchedule`.

    The dtype used for stage-to-stage communication starts as ``torch.float``
    and is switched to ``torch.half`` by :meth:`pre_processing` when the
    engine's model is a :class:`NaiveAMPModel`.

    :param num_microbatches: The number of microbatches
    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
    :type num_microbatches: int
    :type sync_data: bool
    """

    def __init__(self,
                 num_microbatches,
                 sync_data: bool = True):
        super().__init__()

        self.num_microbatches = num_microbatches
        self.sync_data = sync_data
        # dtype of the tensors exchanged between stages; see pre_processing().
        self.dtype = torch.float

    def _move_to_device(self, data):
        """Unwrap a one-element tuple/list and move the tensor to the current device.

        :param data: a tensor, or a tuple/list holding exactly one tensor
        :return: the detached tensor on ``get_current_device()``
        """
        if isinstance(data, (tuple, list)):
            assert len(data) == 1, "Data tuple's length in pipeline should be 1"
            data = data[0]
        assert torch.is_tensor(data), "Data in pipeline should be tensor"
        data = data.to(get_current_device()).detach()
        return data

    def _sync_data(self):
        """Broadcast ``batch_data`` and ``batch_label`` between pipeline ends.

        The first pipeline rank broadcasts from its own global rank over the
        ``PIPELINE_PREV`` group; the last pipeline rank takes part in a
        broadcast over the ``PIPELINE_NEXT`` group whose source is its next
        global rank. All requests are issued asynchronously and awaited at the
        end.
        NOTE(review): presumably both groups pair the first and last stage so
        that the last stage ends up with the first stage's batch -- confirm
        against the group construction in the parallel context.
        """
        reqs = []
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_global_rank()
            for tensor in (self.batch_data, self.batch_label):
                reqs.append(dist.broadcast(
                    tensor=tensor,
                    src=src_rank,
                    group=gpc.get_group(ParallelMode.PIPELINE_PREV),
                    async_op=True
                ))
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
            for tensor in (self.batch_data, self.batch_label):
                reqs.append(dist.broadcast(
                    tensor=tensor,
                    src=src_rank,
                    group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
                    async_op=True
                ))
        for req in reqs:
            req.wait()

    # Pipeline schedule just puts data in memory
    def load_batch(self, data_iter):
        """Fetch one ``(data, label)`` batch and prepare it for micro-batching.

        Resets the micro batch cursor, moves both tensors to the current
        device, derives ``microbatch_size`` from the first dimension of the
        data and, if ``sync_data`` is set, synchronises the batch over the
        pipeline.

        :param data_iter: iterator yielding ``(data, label)`` pairs
        :raises RuntimeError: if ``data_iter`` is ``None``
        """
        if data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        self.batch_pos = 0
        data, label = next(data_iter)
        self.batch_data, self.batch_label = \
            self._move_to_device(data), self._move_to_device(label)
        batch_size = self.batch_data.shape[0]
        assert batch_size % self.num_microbatches == 0, \
            "Batch size should divided by the number of microbatches"
        self.microbatch_size = batch_size // self.num_microbatches
        if self.sync_data:
            self._sync_data()

    def _get_data_slice(self, tensor):
        """Return the slice of ``tensor`` belonging to the current micro batch."""
        return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]

    def load_micro_batch(self):
        """Return the current micro batch and advance the cursor.

        :return: ``((data,), (label,))`` -- one-element tuples of the slices
        """
        data = self._get_data_slice(self.batch_data)
        label = self._get_data_slice(self.batch_label)
        self.batch_pos += self.microbatch_size
        return (data,), (label,)

    def pre_processing(self, engine):
        """Validate the engine and pick the communication dtype.

        :param engine: your engine object
        :raises TypeError: if the optimizer is a ZeRO level 2 or level 3 optimizer
        """
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        if isinstance(engine.model, NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
        """Forward step for passed-in model. If no input tensor is passed in, the
        input is taken from the next micro batch, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.

        On the last pipeline rank the result is also recorded in ``return_tensors``:
        ``(output, label, loss)`` when ``return_loss`` is true, otherwise the output alone.

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param input_tensor: input tensor for this pipeline stage
        :type input_tensor: :class:`torch.Tensor`
        :param return_tensors: a list of tensors to return
        :type return_tensors: List[:class:`torch.Tensor`]
        :param return_loss: whether to compute the loss on the last pipeline rank
        :type return_loss: bool

        :return: output or the loss value of the current pipeline stage
        :rtype: :class:`torch.Tensor`
        """
        label = None
        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = squeeze(engine(input_tensor))

        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            return output_tensor

        if not return_loss:
            return_tensors.append(output_tensor)
            return output_tensor

        # Only read a micro batch for its label when none was read above.
        # Loading a second time would advance batch_pos twice and pair this
        # output with the label of the *next* micro batch.
        if label is None:
            _, label = self.load_micro_batch()
        loss_reduced = engine.criterion(output_tensor, *label) \
            / self.num_microbatches

        return_tensors.append((output_tensor, label[0], loss_reduced))
        return loss_reduced

    def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
        """Backward step through the passed-in output tensor. If output_tensor_grad is
        None, backward is run directly on output_tensor, otherwise it is used as the
        gradients with respect to stage's output tensor.
        Returns the gradients with respect to the input tensor (None if there is no input tensor).
        This is a helper function and can be ignored by users.

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param input_tensor: input tensor for this pipeline stage
        :type input_tensor: :class:`torch.Tensor`
        :param output_tensor: output tensor for this pipeline stage
        :type output_tensor: :class:`torch.Tensor`
        :param output_tensor_grad: gradient of output tensor for this pipeline stage
        :type output_tensor_grad: :class:`torch.Tensor`

        :return: gradient of input tensor
        :rtype: :class:`torch.Tensor`
        """

        # Retain the grad on the input_tensor.
        if input_tensor is not None:
            input_tensor.retain_grad()

        # Backward pass.
        if output_tensor_grad is None:
            engine.backward(output_tensor)
        else:
            engine.backward_by_grad(output_tensor, output_tensor_grad)

        # Collect the grad of the input_tensor.
        input_tensor_grad = None
        if input_tensor is not None:
            input_tensor_grad = input_tensor.grad

        return input_tensor_grad

    def forward_backward_step(self,
                              engine,
                              data_iter,
                              forward_only=False,
                              return_loss=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a 3-tuple; every entry is None on stages that recorded nothing
        (i.e. stages other than the last one).

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param data_iter: dataloader as the form of an iterator, obtained by calling iter(dataloader)
        :type data_iter: Iterable
        :param forward_only: whether run forward step only. Default is false. If true, no backward will be run.
        :type forward_only: bool
        :param return_loss: whether returns the loss value. Default is true.
        :type return_loss: bool

        :return: (output, label, loss); label and loss are None when return_loss is False
        :rtype: Tuple[:class:`torch.Tensor`]
        """

        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

        self.load_batch(data_iter)
        # Earlier stages run more forward passes before their first backward.
        num_warmup_microbatches = \
            (gpc.get_world_size(ParallelMode.PIPELINE) -
             gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
        num_warmup_microbatches = min(num_warmup_microbatches,
                                      self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_tensors = None
        output_tensors = None
        if not forward_only:
            input_tensors = []
            output_tensors = []
        return_tensors = []

        # Used for tensor meta information communication
        ft_shape = None
        bt_shape = None
        fs_checker = True

        # Run warmup forward passes.
        for _ in range(num_warmup_microbatches):
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
            )
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                # The gradient coming back has the shape of what was sent forward.
                bt_shape = output_tensor.shape
                fs_checker = send_tensor_meta(output_tensor, fs_checker)
            send_forward(output_tensor)

            if not forward_only:
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape, dtype=self.dtype)

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
            )
            if forward_only:
                send_forward(output_tensor)

                if not last_iteration:
                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)

            else:
                output_tensor_grad = send_forward_recv_backward(
                    output_tensor, bt_shape, dtype=self.dtype)

                # Add input_tensor and output_tensor to end of list.
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

                # Pop input_tensor and output_tensor from the start of the list for
                # the backward pass.
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                input_tensor_grad = self.backward_step(
                    engine,
                    input_tensor, output_tensor,
                    output_tensor_grad
                )

                if last_iteration:
                    input_tensor = None
                    send_backward(input_tensor_grad)
                else:
                    input_tensor = send_backward_recv_forward(
                        input_tensor_grad, ft_shape, dtype=self.dtype)

        # Run cooldown backward passes.
        if not forward_only:
            for _ in range(num_warmup_microbatches):
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)

                input_tensor_grad = self.backward_step(
                    engine,
                    input_tensor, output_tensor,
                    output_tensor_grad
                )

                send_backward(input_tensor_grad)

        # Nothing was recorded on this stage.
        if not return_tensors:
            return (None, None, None)

        if return_loss:
            output, label, loss = zip(*return_tensors)
            return (torch.cat(output, dim=0),
                    torch.cat(label, dim=0),
                    sum(loss))
        return (torch.cat(return_tensors, dim=0), None, None)
|
|
|
|
|
|
class InterleavedPipelineSchedule(PipelineSchedule):
    """Pipeline schedule running the interleaved 1F1B strategy, in which each
    pipeline rank holds several model chunks (``engine.model`` is indexed per chunk).

    :param num_microbatches: The number of microbatches; must be a multiple of
        the pipeline parallel world size
    :param num_model_chunks: The number of model chunks, registered as the
        virtual pipeline parallel size
    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
    :type num_microbatches: int
    :type num_model_chunks: int
    :type sync_data: bool
    """

    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
            'num_microbatches must be an integer multiple of pipeline parallel world size'
        super().__init__(num_microbatches, sync_data=sync_data)
        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)

    def pre_processing(self, engine):
        """Validate the engine and pick the communication dtype.

        Same checks as the non-interleaved schedule, except that the AMP check
        looks at the first model chunk (``engine.model[0]``).

        :param engine: your engine object
        :raises TypeError: if the optimizer is a ZeRO level 2 or level 3 optimizer
        """
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        if isinstance(engine.model[0], NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
        """Forward step for passed-in model chunk. If no input tensor is passed in, the
        input is taken from the next micro batch, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.

        On the last pipeline stage the result is also recorded in ``return_tensors``:
        ``(output, label, loss)`` when ``return_loss`` is true, otherwise the output alone.

        :param engine: your engine object (its ``criterion`` is used for the loss)
        :param model: the model chunk to run
        :param input_tensor: input tensor for this chunk, or None
        :param return_tensors: a list collecting the results of the last stage
        :param return_loss: whether to compute the loss on the last stage
        :return: output or the loss value of the current pipeline stage
        """
        label = None
        if input_tensor is None:
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = squeeze(model(input_tensor))

        if not gpc.is_pipeline_last_stage():
            return output_tensor

        if not return_loss:
            return_tensors.append(output_tensor)
            return output_tensor

        # Only read a micro batch for its label when none was read above.
        # Loading a second time would advance batch_pos twice and pair this
        # output with the label of the *next* micro batch.
        if label is None:
            _, label = self.load_micro_batch()
        loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
        return_tensors.append((output_tensor, label[0], loss_reduced))
        return loss_reduced

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
        """Run interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

        Returns a 3-tuple ``(output, label, loss)``. Every entry is None on
        stages that recorded nothing; label and loss are None when
        ``return_loss`` is False.

        :param engine: your engine object; ``engine.model`` is indexed per model chunk
        :param data_iter: dataloader as the form of an iterator
        :param forward_only: whether run forward step only. If true, no backward will be run.
        :param return_loss: whether returns the loss value
        """
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)
        model = engine.model
        # One queue per model chunk.
        input_tensors = [[] for _ in range(len(model))]
        output_tensors = [[] for _ in range(len(model))]
        return_tensors = []
        if not forward_only:
            output_tensor_grads = [[] for _ in range(len(model))]

        # Used for tensor meta information communication
        input_tensor_shapes = [None for _ in range(len(model))]
        output_tensor_shapes = [None for _ in range(len(model))]
        send_tensor_shape_flags = [True for _ in range(len(model))]

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # Compute number of warmup and remaining microbatches.
        num_model_chunks = len(model)
        num_microbatches = self.num_microbatches * num_model_chunks
        all_warmup_microbatches = False
        if forward_only:
            num_warmup_microbatches = num_microbatches
        else:
            # Run all forward passes and then all backward passes if number of
            # microbatches is just the number of pipeline stages.
            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
            # all workers, followed by more microbatches after depending on
            # stage ID (more forward passes for earlier stages, later stages can
            # immediately start with 1F1B).
            if self.num_microbatches == pipeline_parallel_size:
                num_warmup_microbatches = num_microbatches
                all_warmup_microbatches = True
            else:
                num_warmup_microbatches = \
                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
                num_warmup_microbatches += (
                    num_model_chunks - 1) * pipeline_parallel_size
                num_warmup_microbatches = min(num_warmup_microbatches,
                                              num_microbatches)
        num_microbatches_remaining = \
            num_microbatches - num_warmup_microbatches

        def get_model_chunk_id(microbatch_id, forward):
            """Helper method to get the model chunk ID given the iteration number."""
            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
            if not forward:
                # Backward passes visit the chunks in reverse order.
                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
            return model_chunk_id

        def forward_step_helper(microbatch_id):
            """Helper method to run forward step with model split into chunks
            (sets the virtual pipeline parallel rank before calling
            forward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            # forward step
            if gpc.is_pipeline_first_stage():
                # The first stage receives nothing; queue a None placeholder
                # unless an input is already pending for this chunk.
                if len(input_tensors[model_chunk_id]) == \
                        len(output_tensors[model_chunk_id]):
                    input_tensors[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id][-1]
            output_tensor = self.forward_step(
                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
            output_tensors[model_chunk_id].append(output_tensor)

            # if forward-only, no need to save tensors for a backward pass
            if forward_only:
                input_tensors[model_chunk_id].pop()
                output_tensors[model_chunk_id].pop()

            return output_tensor

        def backward_step_helper(microbatch_id):
            """Helper method to run backward step with model split into chunks
            (sets the virtual pipeline parallel rank before calling
            backward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            if gpc.is_pipeline_last_stage():
                # The last stage receives no gradient; queue a None placeholder.
                if len(output_tensor_grads[model_chunk_id]) == 0:
                    output_tensor_grads[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id].pop(0)
            output_tensor = output_tensors[model_chunk_id].pop(0)
            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
            input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

            return input_tensor_grad

        # Run warmup forward passes.
        gpc.set_virtual_pipeline_parallel_rank(0)
        if not gpc.is_pipeline_first_stage():
            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))

        for k in range(num_warmup_microbatches):
            model_chunk_id = get_model_chunk_id(k, forward=True)
            output_tensor = forward_step_helper(k)
            if not gpc.is_pipeline_last_stage():
                output_tensor_shapes[model_chunk_id] = output_tensor.shape
                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
                    output_tensor, send_tensor_shape_flags[model_chunk_id])
            # Determine if tensor should be received from previous stage.
            next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                if next_forward_model_chunk_id == 0:
                    recv_prev = False
            if k == (num_microbatches - 1):
                recv_prev = False

            # Don't send tensor downstream if on last stage.
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                if not gpc.is_pipeline_first_stage():
                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
                        input_tensor_shapes[next_forward_model_chunk_id])
            # Send and receive tensors as appropriate (send tensors computed
            # in this iteration; receive tensors for next iteration).
            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            if k == (num_warmup_microbatches - 1) and not forward_only and \
                    not all_warmup_microbatches:
                # Last warmup step before 1F1B: also fetch the first gradient.
                input_tensor_grad = None
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
                input_tensor, output_tensor_grad = \
                    send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        input_shape,
                        output_shape,
                        recv_prev=recv_prev, recv_next=recv_next,
                        dtype=self.dtype)
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = \
                    send_forward_recv_forward(
                        output_tensor,
                        input_shape,
                        recv_prev=recv_prev,
                        dtype=self.dtype)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        # Run 1F1B in steady state.
        for k in range(num_microbatches_remaining):
            # Forward pass.
            forward_k = k + num_warmup_microbatches
            output_tensor = forward_step_helper(forward_k)

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True)
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                                 forward=True)

            recv_next = True
            if gpc.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False)
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                                  forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # The chunk ids above may be out of range when nothing is received;
            # the conditional expressions only index the lists when receiving.
            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
            # Communicate tensors.
            input_tensor, output_tensor_grad = \
                send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    input_shape,
                    output_shape,
                    recv_prev=recv_prev, recv_next=recv_next,
                    dtype=self.dtype)

            # Put input_tensor and output_tensor_grad in data structures in the
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(
                    output_tensor_grad)

        # Run cooldown backward passes (flush out pipeline).
        if not forward_only:
            if all_warmup_microbatches:
                # Pass the schedule's dtype like every other p2p call here, so
                # the receive matches what the peer sends under NaiveAMP.
                output_tensor_grads[num_model_chunks - 1].append(
                    recv_backward(output_tensor_shapes[num_model_chunks - 1],
                                  dtype=self.dtype))
            for k in range(num_microbatches_remaining, num_microbatches):
                input_tensor_grad = backward_step_helper(k)
                next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_model_chunk_id == (num_model_chunks - 1):
                        recv_next = False
                if k == (num_microbatches - 1):
                    recv_next = False
                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                output_tensor_grads[next_backward_model_chunk_id].append(
                    send_backward_recv_backward(
                        input_tensor_grad,
                        output_shape,
                        recv_next=recv_next,
                        dtype=self.dtype))

        # Nothing was recorded on this stage.
        if not return_tensors:
            return (None, None, None)

        if return_loss:
            output, label, loss = zip(*return_tensors)
            return (torch.cat(output, dim=0),
                    torch.cat(label, dim=0),
                    sum(loss))
        return (torch.cat(return_tensors, dim=0), None, None)
|