mirror of https://github.com/hpcaitech/ColossalAI
add scatter/gather optim for pipeline (#123)
parent 404e6f88ed
commit 293fb40c42
@@ -13,5 +13,5 @@ __all__ = [
     'send_forward_backward_recv_forward_backward', 'send_backward',
     'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
-]
+    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
+]

@@ -1,12 +1,42 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from functools import reduce
+import operator
+from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
+
+
+def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
+    """get the exact tensor shape when communicating and return whether the tensor is a chunk
+
+    :param tensor_shape: shape of tensor
+    :type tensor_shape: TensorShape
+    :param chunk_tensor: whether to chunk tensor, defaults to False
+    :type chunk_tensor: bool, optional
+    :return: exact tensor shape, whether to chunk tensor
+    :rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
+    """
+    if chunk_tensor:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if tensor_chunk_shape % tensor_parallel_world_size == 0:
+            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
+        else:
+            tensor_chunk_shape = tensor_shape
+            chunk_tensor = False
+    else:
+        tensor_chunk_shape = tensor_shape
+    return tensor_chunk_shape, chunk_tensor
 
 
 def _communicate(tensor_send_next=None,
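
Note (not part of the diff): a minimal standalone sketch of the chunking rule that _get_tensor_shape implements. The helper name chunk_shape and the hard-coded tensor-parallel world size of 8 are assumptions used in place of the gpc world-size lookup.

import operator
from functools import reduce

def chunk_shape(tensor_shape, chunk_tensor, tensor_parallel_world_size=8):
    # Flatten the shape to an element count and split it across the tensor-parallel
    # group only when it divides evenly; otherwise fall back to the full shape.
    if chunk_tensor:
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % tensor_parallel_world_size == 0:
            return numel // tensor_parallel_world_size, True
    return tensor_shape, False

print(chunk_shape((4, 512, 1024), True))   # (262144, True): each rank communicates 1/8 of the elements
print(chunk_shape((3, 5, 7), True))        # ((3, 5, 7), False): 105 is not divisible by 8
print(chunk_shape((4, 512, 1024), False))  # ((4, 512, 1024), False): chunking disabled
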
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
                  recv_next_shape=None,
                  prev_rank=None,
                  next_rank=None,
-                 dtype=None):
+                 dtype=None,
+                 scatter_gather_tensors=False):
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,
 
     if recv_prev:
         assert recv_prev_shape is not None
-        tensor_recv_prev = torch.empty(recv_prev_shape,
+        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
     if recv_next:
         assert recv_next_shape is not None
-        tensor_recv_next = torch.empty(recv_next_shape,
+        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
@@ -63,6 +96,16 @@ def _communicate(tensor_send_next=None,
             next_rank = gpc.get_next_global_rank(
                 ParallelMode.PIPELINE)
 
+    if tensor_send_prev is not None:
+        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
+        if send_prev_split:
+            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
+    if tensor_send_next is not None:
+        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
+        if send_next_split:
+            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -82,10 +125,15 @@ def _communicate(tensor_send_next=None,
             req.wait()
         # To protect against race condition when using batch_isend_irecv().
         torch.cuda.synchronize()
+
+    if recv_prev and recv_prev_split:
+        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+    if recv_next and recv_next_split:
+        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the input tensor from the previous member in pipeline.
 
     :param input_tensor_shape: The shape of the tensor to be recieved
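
Note (not part of the diff): a usage sketch of the forward exchange on a middle pipeline stage with the new flag, assuming a ColossalAI run where pipeline parallelism and 1D tensor parallelism are already initialized; compute_fn and input_tensor_shape are placeholders.

import torch
from colossalai.communication import recv_forward, send_forward

def pipeline_forward_step(input_tensor_shape, compute_fn, use_scatter_gather=True):
    # Receive the activation from the previous stage; with scatter_gather_tensors=True
    # only a 1/TP-size chunk travels over the wire and is gathered back to
    # input_tensor_shape on arrival.
    input_tensor = recv_forward(input_tensor_shape,
                                dtype=torch.float,
                                scatter_gather_tensors=use_scatter_gather)
    output_tensor = compute_fn(input_tensor)
    # Send the result onward; it is split into equal 1D chunks again before isend
    # whenever the element count divides evenly by the tensor-parallel world size.
    send_forward(output_tensor, scatter_gather_tensors=use_scatter_gather)
    return output_tensor
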
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
         input_tensor, _ = _communicate(recv_prev=True,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
-                                       dtype=dtype)
+                                       dtype=dtype,
+                                       scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the grad tensor from the next member in pipeline.
 
     :param output_grad_shape: The shape of the tensor to be recieved
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
         _, output_tensor_grad = _communicate(recv_next=True,
                                              recv_next_shape=output_grad_shape,
                                              next_rank=next_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
-def send_forward(output_tensor, next_rank=None):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     """Sends the input tensor to the next member in pipeline.
 
     :param output_tensor: Tensor to be sent
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
     """
     if not gpc.is_pipeline_last_stage():
         _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank)
+                     next_rank=next_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)
 
 
-def send_backward(input_tensor_grad, prev_rank=None):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
     """Sends the grad tensor to the previous member in pipeline.
 
     :param input_tensor_grad: Tensor to be sent
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
     """
     if not gpc.is_pipeline_first_stage():
         _communicate(tensor_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank)
+                     prev_rank=prev_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
                                next_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
                                              next_rank=next_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
                                prev_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                        recv_prev=recv_prev,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
-                                       dtype=dtype)
+                                       dtype=dtype,
+                                       scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
                               recv_prev=True,
                               prev_rank=None,
                               next_rank=None,
-                              dtype=torch.float):
+                              dtype=torch.float,
+                              scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
                                        next_rank=next_rank,
-                                       dtype=dtype)
+                                       dtype=dtype,
+                                       scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                 recv_next=True,
                                 prev_rank=None,
                                 next_rank=None,
-                                dtype=torch.float):
+                                dtype=torch.float,
+                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                              recv_next_shape=output_grad_shape,
                                              prev_rank=prev_rank,
                                              next_rank=next_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_next=True,
                                                 prev_rank=None,
                                                 next_rank=None,
-                                                dtype=torch.float):
+                                                dtype=torch.float,
+                                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while recieves the grad tensor from the
     next and the input tensor from the previous.
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                  recv_next_shape=output_grad_shape,
                                                  prev_rank=prev_rank,
                                                  next_rank=next_rank,
-                                                 dtype=dtype)
+                                                 dtype=dtype,
+                                                 scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad

@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
         tensor_shape = torch.Size(recv_shape)
 
     return tensor_shape
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """Break a tensor into equal 1D chunks."""
+    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return gathered
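
Note (not part of the diff): a single-process illustration of the invariant the two helpers above rely on: concatenating the per-rank 1D slices in rank order reproduces the flattened tensor, so an all-gather of the chunks followed by view() restores the original shape. torch.chunk stands in for the PARALLEL_1D process group here.

import torch

world_size = 4
x = torch.arange(2 * 6, dtype=torch.float32).reshape(2, 6)

# split_tensor_into_1d_equal_chunks: each rank keeps one contiguous 1D slice.
chunks = list(torch.chunk(x.view(-1), world_size))

# gather_split_1d_tensor: all-gathering the slices and viewing the result as the
# original shape gives back exactly the tensor that was split.
restored = torch.cat(chunks).view(x.shape)
assert torch.equal(restored, x)
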
@@ -6,7 +6,7 @@ import inspect
 import torch.cuda
 from torch import Tensor
 
-from colossalai.communication import *
+import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
     :type num_microbatches: int
     :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
     :type batch_data_process_func: Callable
+    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+    :type scatter_gather_tensors: bool
     """
 
     def __init__(self,
                  num_microbatches,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
+        self.scatter_gather_tensors = False
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
+            self.scatter_gather_tensors = scatter_gather_tensors
 
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
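
Note (not part of the diff): a construction sketch for the new argument, assuming a training script launched with both pipeline parallelism and 1D tensor parallelism; the import path and the tensor_shape value are assumptions. As the __init__ above shows, the flag only takes effect when ParallelMode.PARALLEL_1D is initialized with a world size greater than 1; otherwise it silently stays off.

from colossalai.engine.schedule import PipelineSchedule

# Hypothetical values: 4 microbatches per step and a fixed per-microbatch
# activation shape so the tensor-meta exchange can be skipped.
schedule = PipelineSchedule(num_microbatches=4,
                            tensor_shape=(2, 1024, 1024),
                            scatter_gather_tensors=True)
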
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_output_label=return_output_label,
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
             )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
-                fs_checker = send_tensor_meta(output_tensor, fs_checker)
-            send_forward(output_tensor)
+                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
+            comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
             if not forward_only:
                 input_tensors.append(input_tensor)
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss=accum_loss
             )
             if forward_only:
-                send_forward(output_tensor)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                    input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                     scatter_gather_tensors=self.scatter_gather_tensors)
 
             else:
-                output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.send_forward_recv_backward(
+                    output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
 
                 if last_iteration:
                     input_tensor = None
-                    send_backward(input_tensor_grad)
+                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape, dtype=self.dtype)
+                    input_tensor = comm.send_backward_recv_forward(
+                        input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run cooldown backward passes.
         if not forward_only:
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
+                                                        scatter_gather_tensors=self.scatter_gather_tensors)
 
                 input_tensor_grad = self.backward_step(
                     engine,
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                     output_tensor_grad
                 )
 
-                send_backward(input_tensor_grad)
+                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                  num_microbatches,
                  num_model_chunks,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
         It uses interleaved 1F1B strategy. Other properties are similar as
         :class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         :type num_model_chunks: int
         :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
         :type batch_data_process_func: Callable
+        :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+        :type scatter_gather_tensors: bool
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors))
 
         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor, send_tensor_shape_flags[model_chunk_id])
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
 
             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
                         input_tensor_shapes[next_forward_model_chunk_id])
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
                 input_tensor, output_tensor_grad = \
-                    send_forward_backward_recv_forward_backward(
+                    comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
                         input_shape,
                         output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
             else:
                 input_tensor = \
-                    send_forward_recv_forward(
+                    comm.send_forward_recv_forward(
                         output_tensor,
                         input_shape,
                         recv_prev=recv_prev,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
         # Run 1F1B in steady state.
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
             # Communicate tensors.
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                comm.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     input_shape,
                     output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
-                    dtype=self.dtype)
+                    dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)
 
             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks-1].append(
-                    recv_backward(output_tensor_shapes[num_model_chunks-1]))
+                    comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
-                    send_backward_recv_backward(
+                    comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
-                        dtype=self.dtype))
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors))
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))

@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
-        # TODO: pipeline only support NAIVE AMP
         cfg_ = fp16_cfg.copy()
         amp_mode = cfg_.pop('mode')
+        if is_using_pp():
+            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
             cfg_['clip_grad'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
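
Note (not part of the diff): a minimal fp16 config sketch consistent with the new assertion above; with pipeline parallelism enabled, only AMP_TYPE.NAIVE passes the check. The config field name follows the usual ColossalAI convention and is an assumption here.

from colossalai.amp import AMP_TYPE

# Pipeline parallelism currently requires naive AMP; other AMP modes would
# trip the assert added in initialize().
fp16 = dict(mode=AMP_TYPE.NAIVE)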