mirror of https://github.com/hpcaitech/ColossalAI

add scatter/gather optim for pipeline (#123)

parent 404e6f88ed
commit 293fb40c42

@@ -13,5 +13,5 @@ __all__ = [
     'send_forward_backward_recv_forward_backward', 'send_backward',
     'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
+    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
 ]
@@ -1,12 +1,42 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from functools import reduce
+import operator
+from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
+
+
+def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
+    """get the exact tensor shape when communicating and return whether the tensor is a chunk
+
+    :param tensor_shape: shape of tensor
+    :type tensor_shape: TensorShape
+    :param chunk_tensor: whether to chunk tensor, defaults to False
+    :type chunk_tensor: bool, optional
+    :return: exact tensor shape, whether to chunk tensor
+    :rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
+    """
+    if chunk_tensor:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if tensor_chunk_shape % tensor_parallel_world_size == 0:
+            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
+        else:
+            tensor_chunk_shape = tensor_shape
+            chunk_tensor = False
+    else:
+        tensor_chunk_shape = tensor_shape
+    return tensor_chunk_shape, chunk_tensor


 def _communicate(tensor_send_next=None,
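
The chunking rule above is easiest to see with concrete numbers. The sketch below mirrors the same logic outside a launched job, so the tensor-parallel world size is passed in explicitly instead of being read from `gpc`; the helper name `chunk_shape` and the shapes are made up for illustration.

from functools import reduce
import operator

def chunk_shape(tensor_shape, chunk_tensor, tensor_parallel_world_size):
    # Mirrors _get_tensor_shape: flatten to a 1D chunk only when the element
    # count divides evenly across the tensor-parallel ranks.
    if chunk_tensor:
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % tensor_parallel_world_size == 0:
            return numel // tensor_parallel_world_size, True
        return tensor_shape, False
    return tensor_shape, False

print(chunk_shape((4, 128, 1024), True, 4))   # -> (131072, True): each rank sends 1/4 of the elements
print(chunk_shape((5, 7), True, 4))           # -> ((5, 7), False): 35 is not divisible by 4, fall back
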
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
                  recv_next_shape=None,
                  prev_rank=None,
                  next_rank=None,
-                 dtype=None):
+                 dtype=None,
+                 scatter_gather_tensors=False):
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,

     if recv_prev:
         assert recv_prev_shape is not None
-        tensor_recv_prev = torch.empty(recv_prev_shape,
+        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
     if recv_next:
         assert recv_next_shape is not None
-        tensor_recv_next = torch.empty(recv_next_shape,
+        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
@@ -63,6 +96,16 @@
             next_rank = gpc.get_next_global_rank(
                 ParallelMode.PIPELINE)

+    if tensor_send_prev is not None:
+        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
+        if send_prev_split:
+            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
+    if tensor_send_next is not None:
+        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
+        if send_next_split:
+            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -82,10 +125,15 @@
             req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
+
+    if recv_prev and recv_prev_split:
+        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+    if recv_next and recv_next_split:
+        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the input tensor from the previous member in pipeline.

     :param input_tensor_shape: The shape of the tensor to be received
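
Taken together, the two halves above form the scatter/gather path: the sender forwards only its 1D slice of the activation (or gradient), and the receiver all-gathers the slices inside its 1D tensor-parallel group before handing the full tensor onward. A rough sense of the saving on inter-stage traffic, with made-up shapes and fp16 activations assumed:

# Illustrative only; the shape, dtype and world size below are placeholders.
batch, seq, hidden = 8, 1024, 4096
tensor_parallel_world_size = 4
bytes_per_elem = 2  # fp16

full_msg = batch * seq * hidden * bytes_per_elem
chunk_msg = full_msg // tensor_parallel_world_size

print(f"P2P payload per microbatch, plain send:        {full_msg / 2**20:.0f} MiB")
print(f"P2P payload per microbatch, scatter/gather on: {chunk_msg / 2**20:.0f} MiB")
# The receiver then pays one all-gather inside its 1D tensor-parallel group
# to rebuild the full tensor, which usually rides a faster intra-node link.
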
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
     input_tensor, _ = _communicate(recv_prev=True,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
-                                   dtype=dtype)
+                                   dtype=dtype,
+                                   scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor


-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the grad tensor from the next member in pipeline.

     :param output_grad_shape: The shape of the tensor to be received
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
     _, output_tensor_grad = _communicate(recv_next=True,
                                          recv_next_shape=output_grad_shape,
                                          next_rank=next_rank,
-                                         dtype=dtype)
+                                         dtype=dtype,
+                                         scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad


-def send_forward(output_tensor, next_rank=None):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     """Sends the input tensor to the next member in pipeline.

     :param output_tensor: Tensor to be sent
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
     """
     if not gpc.is_pipeline_last_stage():
         _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank)
+                     next_rank=next_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)


-def send_backward(input_tensor_grad, prev_rank=None):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
     """Sends the grad tensor to the previous member in pipeline.

     :param input_tensor_grad: Tensor to be sent
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
     """
     if not gpc.is_pipeline_first_stage():
         _communicate(tensor_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank)
+                     prev_rank=prev_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)


 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
                                next_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while receives the grad tensor from the
     next member in pipeline.
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
                                          recv_next=recv_next,
                                          recv_next_shape=output_grad_shape,
                                          next_rank=next_rank,
-                                         dtype=dtype)
+                                         dtype=dtype,
+                                         scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad


@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
                                prev_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while receives the input tensor from the
     previous member in pipeline.
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                    recv_prev=recv_prev,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
-                                   dtype=dtype)
+                                   dtype=dtype,
+                                   scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor


@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
                               recv_prev=True,
                               prev_rank=None,
                               next_rank=None,
-                              dtype=torch.float):
+                              dtype=torch.float,
+                              scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while receives the input tensor from the
     previous member in pipeline.
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
                                    next_rank=next_rank,
-                                   dtype=dtype)
+                                   dtype=dtype,
+                                   scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor


@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                 recv_next=True,
                                 prev_rank=None,
                                 next_rank=None,
-                                dtype=torch.float):
+                                dtype=torch.float,
+                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while receives the grad tensor from the
     next member in pipeline.
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                          recv_next_shape=output_grad_shape,
                                          prev_rank=prev_rank,
                                          next_rank=next_rank,
-                                         dtype=dtype)
+                                         dtype=dtype,
+                                         scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad


@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                  recv_next=True,
                                                  prev_rank=None,
                                                  next_rank=None,
-                                                 dtype=torch.float):
+                                                 dtype=torch.float,
+                                                 scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while receives the grad tensor from the
     next and the input tensor from the previous.
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                             recv_next_shape=output_grad_shape,
                                                             prev_rank=prev_rank,
                                                             next_rank=next_rank,
-                                                            dtype=dtype)
+                                                            dtype=dtype,
+                                                            scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad
@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
         tensor_shape = torch.Size(recv_shape)

     return tensor_shape
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """Break a tensor into equal 1D chunks."""
+    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return gathered
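
The two helpers are inverses of each other over the 1D tensor-parallel group. The single-process sketch below imitates the round trip with plain torch ops (no process group involved), just to show that concatenating every rank's slice and reshaping recovers the original tensor; the shapes are arbitrary.

import torch

world_size = 4
full = torch.randn(2, 8, 16)                             # 256 elements, divisible by world_size

slices = list(torch.chunk(full.view(-1), world_size))    # the slice each rank would keep/send
rebuilt = torch.cat(slices).view(full.shape)             # what gather_split_1d_tensor reconstructs

assert torch.equal(rebuilt, full)
print(f"{slices[0].numel()} elements sent per rank instead of {full.numel()}")
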
@@ -6,7 +6,7 @@ import inspect
 import torch.cuda
 from torch import Tensor

-from colossalai.communication import *
+import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
     :type num_microbatches: int
     :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
     :type batch_data_process_func: Callable
+    :param scatter_gather_tensors: If set to `True`, communication over the pipeline is reduced when 1D tensor parallelization is used
+    :type scatter_gather_tensors: bool
     """

     def __init__(self,
                  num_microbatches,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
+        self.scatter_gather_tensors = False
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
+            self.scatter_gather_tensors = scatter_gather_tensors

     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
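
A sketch of how the new flag could be passed when building the schedule; the numbers are placeholders, and the import path assumes the schedule classes are re-exported from `colossalai.engine.schedule` as in this revision. As the constructor above shows, the flag only takes effect when a 1D tensor-parallel group with more than one rank has been initialized.

from colossalai.engine.schedule import PipelineSchedule

# Placeholder values; tensor_shape is the per-microbatch activation shape between stages.
schedule = PipelineSchedule(num_microbatches=4,
                            tensor_shape=(16, 1024, 1024),
                            scatter_gather_tensors=True)
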
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_output_label=return_output_label,
@@ -236,8 +243,8 @@
             )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
-                fs_checker = send_tensor_meta(output_tensor, fs_checker)
-            send_forward(output_tensor)
+                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
+            comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)

             if not forward_only:
                 input_tensors.append(input_tensor)
@@ -248,8 +255,9 @@
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@
                 accum_loss=accum_loss
             )
             if forward_only:
-                send_forward(output_tensor)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)

                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                    input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                     scatter_gather_tensors=self.scatter_gather_tensors)

             else:
-                output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.send_forward_recv_backward(
+                    output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)

                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@

                 if last_iteration:
                     input_tensor = None
-                    send_backward(input_tensor_grad)
+                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape, dtype=self.dtype)
+                    input_tensor = comm.send_backward_recv_forward(
+                        input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)

         # Run cooldown backward passes.
         if not forward_only:
@@ -298,7 +307,8 @@
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)

-                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
+                                                        scatter_gather_tensors=self.scatter_gather_tensors)

                 input_tensor_grad = self.backward_step(
                     engine,
@@ -306,7 +316,7 @@
                     output_tensor_grad
                 )

-                send_backward(input_tensor_grad)
+                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                  num_microbatches,
                  num_model_chunks,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
         It uses interleaved 1F1B strategy. Other properties are similar to
         :class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@
         :type num_model_chunks: int
         :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
         :type batch_data_process_func: Callable
+        :param scatter_gather_tensors: If set to `True`, communication over the pipeline is reduced when 1D tensor parallelization is used
+        :type scatter_gather_tensors: bool
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks
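
The interleaved variant simply forwards the flag to the parent constructor; the only extra constraint is the assertion above, so num_microbatches must be an integer multiple of the pipeline-parallel world size. A minimal construction sketch with placeholder numbers, under the same import-path assumption as before:

from colossalai.engine.schedule import InterleavedPipelineSchedule

# e.g. a 4-stage pipeline: 8 microbatches is a valid multiple, with 2 model chunks per stage.
schedule = InterleavedPipelineSchedule(num_microbatches=8,
                                       num_model_chunks=2,
                                       scatter_gather_tensors=True)
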
@@ -494,15 +508,16 @@
         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors))

         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor, send_tensor_shape_flags[model_chunk_id])
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@

             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
                         input_tensor_shapes[next_forward_model_chunk_id])
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
                 input_tensor, output_tensor_grad = \
-                    send_forward_backward_recv_forward_backward(
+                    comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
                         input_shape,
                         output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
             else:
                 input_tensor = \
-                    send_forward_recv_forward(
+                    comm.send_forward_recv_forward(
                         output_tensor,
                         input_shape,
                         recv_prev=recv_prev,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
             input_tensors[next_forward_model_chunk_id].append(input_tensor)

         # Run 1F1B in steady state.
@@ -608,12 +625,13 @@
             output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
             # Communicate tensors.
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                comm.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     input_shape,
                     output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
-                    dtype=self.dtype)
+                    dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)

             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
@@ -627,7 +645,7 @@
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks-1].append(
-                    recv_backward(output_tensor_shapes[num_model_chunks-1]))
+                    comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
-                    send_backward_recv_backward(
+                    comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
-                        dtype=self.dtype))
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors))

         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
-        # TODO: pipeline only support NAIVE AMP
         cfg_ = fp16_cfg.copy()
         amp_mode = cfg_.pop('mode')
+        if is_using_pp():
+            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only supports NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
             cfg_['clip_grad'] = clip_grad_norm
             model, optimizer, criterion = convert_to_amp(model=model,
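
With this check, enabling mixed precision together with pipeline parallelism requires naive AMP. In a user config that would look roughly like the sketch below; the parallel sizes are placeholders and the dict-style config follows the usual ColossalAI convention rather than anything introduced by this commit.

# config.py (sketch)
from colossalai.amp import AMP_TYPE

# Placeholder topology: 2 pipeline stages, 1D tensor parallelism of size 4.
parallel = dict(pipeline=2, tensor=dict(size=4, mode='1d'))

# Pipeline runs must use naive AMP; other modes now trip the assertion above.
fp16 = dict(mode=AMP_TYPE.NAIVE)
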