add scatter/gather optim for pipeline (#123)

pull/126/head^2
ver217 2022-01-07 13:22:22 +08:00 committed by GitHub
parent 404e6f88ed
commit 293fb40c42
5 changed files with 166 additions and 56 deletions

View File

@@ -13,5 +13,5 @@ __all__ = [
'send_forward_backward_recv_forward_backward', 'send_backward',
'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
]

View File

@@ -1,12 +1,42 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from functools import reduce
import operator
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
:param tensor_shape: shape of tensor
:type tensor_shape: TensorShape
:param chunk_tensor: whether to chunk tensor, defaults to False
:type chunk_tensor: bool, optional
:return: exact tensor shape, whether to chunk tensor
:rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
"""
if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
if tensor_chunk_shape % tensor_parallel_world_size == 0:
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
else:
tensor_chunk_shape = tensor_shape
chunk_tensor = False
else:
tensor_chunk_shape = tensor_shape
return tensor_chunk_shape, chunk_tensor
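The helper above decides whether a tensor can travel as a flat 1D chunk: if the flattened element count divides evenly by the 1D tensor-parallel world size, only `numel // world_size` elements are communicated per rank; otherwise it falls back to the full shape. A minimal standalone sketch of that decision, with the world size passed explicitly (an assumption for illustration, since the real helper reads it from `gpc`):

```python
# Standalone sketch of the chunking decision in _get_tensor_shape.
# world_size is an explicit argument here (illustrative assumption);
# the real helper reads gpc.get_world_size(ParallelMode.TENSOR).
from functools import reduce
import operator

def chunked_comm_shape(tensor_shape, world_size, chunk_tensor=True):
    if not chunk_tensor:
        return tensor_shape, False
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % world_size == 0:
        return numel // world_size, True   # communicate a flat 1D chunk
    return tensor_shape, False             # not evenly divisible: send the whole tensor

print(chunked_comm_shape((4, 512, 1024), world_size=4))  # (524288, True)
print(chunked_comm_shape((4, 512, 1024), world_size=3))  # ((4, 512, 1024), False)
```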
def _communicate(tensor_send_next=None,
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
dtype=None):
dtype=None,
scatter_gather_tensors=False):
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev = torch.empty(recv_prev_shape,
recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next = torch.empty(recv_next_shape,
recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
tensor_recv_next = torch.empty(recv_next_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
@@ -63,6 +96,16 @@ def _communicate(tensor_send_next=None,
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if tensor_send_prev is not None:
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
if send_prev_split:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
if tensor_send_next is not None:
send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
if send_next_split:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -82,10 +125,15 @@ def _communicate(tensor_send_next=None,
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
if recv_prev and recv_prev_split:
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
if recv_next and recv_next_split:
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
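The payoff of the new flag: with 1D tensor parallelism every rank in the tensor-parallel group holds an identical copy of the boundary activation, so each rank only needs to send its own 1/world_size slice over the pipeline link, and the receiver reassembles the full tensor with an all-gather inside the tensor-parallel group. A back-of-envelope estimate with hypothetical sizes:

```python
# Hypothetical sizes: micro-batch 4, sequence 1024, hidden 1024,
# fp16 activations, 1D tensor-parallel world size 4.
numel = 4 * 1024 * 1024                    # elements in the boundary activation
tp_world_size = 4

bytes_full = numel * 2                     # every rank sends the whole fp16 tensor
bytes_chunk = numel // tp_world_size * 2   # each rank sends only its 1D slice

print(f"p2p volume per rank without scatter/gather: {bytes_full / 2**20:.1f} MiB")   # 8.0 MiB
print(f"p2p volume per rank with    scatter/gather: {bytes_chunk / 2**20:.1f} MiB")  # 2.0 MiB
```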
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be received
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be received
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
def send_forward(output_tensor, next_rank=None):
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
"""
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank)
next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None):
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
"""
if not gpc.is_pipeline_first_stage():
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank)
prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the grad tensor from the
next member in pipeline.
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the input tensor from the
previous member in pipeline.
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the input tensor from the
previous member in pipeline.
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the grad tensor from the
next member in pipeline.
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float):
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while receiving the grad tensor from the
next and the input tensor from the previous.
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype)
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor, output_tensor_grad
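Taken together, the new `scatter_gather_tensors` keyword threads through every send/recv helper in this file. A usage sketch follows; it is illustrative only, since it assumes ColossalAI has already been launched with pipeline parallelism plus 1D tensor parallelism, and the tensor shape is made up:

```python
import torch
import colossalai.communication as comm
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

ft_shape = (4, 1024, 1024)   # hypothetical boundary-activation shape

# Receive from the previous stage; the 1D chunks are all-gathered back into
# the full tensor before being returned.
if not gpc.is_first_rank(ParallelMode.PIPELINE):
    input_tensor = comm.recv_forward(ft_shape, dtype=torch.half,
                                     scatter_gather_tensors=True)
else:
    input_tensor = torch.randn(ft_shape, dtype=torch.half, device='cuda')

output_tensor = input_tensor  # stand-in for this stage's forward pass

# Send to the next stage; only this rank's 1D slice goes over the pipeline link.
if not gpc.is_last_rank(ParallelMode.PIPELINE):
    comm.send_forward(output_tensor, scatter_gather_tensors=True)
```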

View File

@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
tensor_shape = torch.Size(recv_shape)
return tensor_shape
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered
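A CPU-only round-trip check of what the two helpers do together: each 1D tensor-parallel rank keeps one contiguous slice of the flattened tensor, and concatenating the slices in rank order recovers the original. The snippet below fakes the process group with a plain loop, so it runs without any distributed setup:

```python
import torch

world_size = 4
full = torch.arange(4 * 6, dtype=torch.float32).view(4, 6)  # numel 24, divisible by 4

partition = full.numel() // world_size
# what split_tensor_into_1d_equal_chunks would return on each rank
chunks = [full.view(-1)[r * partition:(r + 1) * partition] for r in range(world_size)]

# what gather_split_1d_tensor's all_gather produces, viewed back to the original shape
gathered = torch.cat(chunks).view(full.shape)
assert torch.equal(gathered, full)
```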

View File

@@ -6,7 +6,7 @@ import inspect
import torch.cuda
from torch import Tensor
from colossalai.communication import *
import colossalai.communication as comm
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
:type num_microbatches: int
:param batch_data_process_func: The preprocessing function that receives a batch of data; it is executed in `load_batch`
:type batch_data_process_func: Callable
:param scatter_gather_tensors: If set to `True`, pipeline communication is reduced by scattering and gathering tensors when 1D tensor parallelism is used
:type scatter_gather_tensors: bool
"""
def __init__(self,
num_microbatches,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
super().__init__(batch_data_process_func=batch_data_process_func)
self.num_microbatches = num_microbatches
self.dtype = torch.float
self.tensor_shape = tensor_shape
self.scatter_gather_tensors = False
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
self.scatter_gather_tensors = scatter_gather_tensors
def load_batch(self, data_iter):
# Pipeline schedule just puts data in memory
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self.forward_step(
engine, input_tensor, return_tensors,
return_output_label=return_output_label,
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
send_forward(output_tensor)
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if not forward_only:
input_tensors.append(input_tensor)
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
accum_loss=accum_loss
)
if forward_only:
send_forward(output_tensor)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if not last_iteration:
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, bt_shape, dtype=self.dtype)
output_tensor_grad = comm.send_forward_recv_backward(
output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, ft_shape, dtype=self.dtype)
input_tensor = comm.send_backward_recv_forward(
input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
# Run cooldown backward passes.
if not forward_only:
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self.backward_step(
engine,
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad
)
send_backward(input_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
if len(return_tensors) > 0:
output, label = tuple(map(list, zip(*return_tensors)))
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
num_microbatches,
num_model_chunks,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
"""A helper schedule class for pipeline parallelism running environment.
It uses interleaved 1F1B strategy. Other properties are similar as
:class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
:type num_model_chunks: int
:param batch_data_process_func: The preprocessing function that receives a batch of data; it is executed in `load_batch`
:type batch_data_process_func: Callable
:param scatter_gather_tensors: If set to `True`, pipeline communication is reduced by scattering and gathering tensors when 1D tensor parallelism is used
:type scatter_gather_tensors: bool
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
gpc.set_virtual_pipeline_parallel_rank(0)
self.num_model_chunks = num_model_chunks
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# Run warmup forward passes.
gpc.set_virtual_pipeline_parallel_rank(0)
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
output_tensor, send_tensor_shape_flags[model_chunk_id])
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id])
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
send_forward_recv_forward(
comm.send_forward_recv_forward(
output_tensor,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate tensors.
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
recv_backward(output_tensor_shapes[num_model_chunks-1]))
comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
output_tensor_grads[next_backward_model_chunk_id].append(
send_backward_recv_backward(
comm.send_backward_recv_backward(
input_tensor_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype))
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
if len(return_tensors) > 0:
output, label = tuple(map(list, zip(*return_tensors)))
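From user code, the optimization is switched on when constructing the schedule; if the 1D tensor-parallel group is missing or has world size 1, the `__init__` shown above silently falls back to `scatter_gather_tensors=False`. A sketch, assuming ColossalAI was launched with a config enabling both pipeline and 1D tensor parallelism (the import path and config keys below are assumptions):

```python
# Assumed launch config, e.g.: parallel = dict(pipeline=2, tensor=dict(size=4, mode='1d'))
from colossalai.engine.schedule import PipelineSchedule, InterleavedPipelineSchedule

schedule = PipelineSchedule(num_microbatches=4,
                            scatter_gather_tensors=True)

# Interleaved variant: num_microbatches must be a multiple of the pipeline
# parallel world size (checked by the assert in __init__ above).
interleaved = InterleavedPipelineSchedule(num_microbatches=8,
                                          num_model_chunks=2,
                                          scatter_gather_tensors=True)
```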

View File

@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# initialize amp
amp_mode = None
if fp16_cfg is not None and fp16_cfg.mode is not None:
# TODO: pipeline only support NAIVE AMP
cfg_ = fp16_cfg.copy()
amp_mode = cfg_.pop('mode')
if is_using_pp():
assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only supports NaiveAMP currently'
if amp_mode == AMP_TYPE.NAIVE:
cfg_['clip_grad'] = clip_grad_norm
model, optimizer, criterion = convert_to_amp(model=model,
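Since the new check requires NaiveAMP whenever pipeline parallelism is active, a matching mixed-precision config might look like the sketch below (the exact config keys and layout are an assumption, not taken from this commit):

```python
# Hypothetical config entries for pipeline parallelism with NaiveAMP.
from colossalai.amp import AMP_TYPE

fp16 = dict(mode=AMP_TYPE.NAIVE)
parallel = dict(pipeline=2, tensor=dict(size=4, mode='1d'))  # assumed parallel layout
```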