@@ -1,8 +1,9 @@
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
+import torch.cuda
 import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
@@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule


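+# Wait on a list of async P2P work handles (each exposes .wait()); a None argument is a no-op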
+def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
+    if wait_handles is not None:
+        for req in wait_handles:
+            req.wait()
+
+
 class InterleavedSchedule(PipelineSchedule):
     def __init__(
         self,
@@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
+        overlap_p2p: bool = True,
     ) -> None:
         super().__init__(stage_manager)
         assert (
             num_microbatch is not None or microbatch_size is not None
         ), "Either num_microbatch or microbatch_size should be provided"

-        self.comm = PipelineP2PCommunication(stage_manager)
+        self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
+        self.overlap_p2p = overlap_p2p
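+        # With overlap_p2p enabled, the send/recv calls below return async wait handles instead of blocking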
         self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
         self.num_model_chunks = num_model_chunks
@@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             int: The model chunk idx of the input microbatch_id
         """
-        assert microbatch_id < self.num_microbatch * self.num_model_chunks
+        assert (
+            microbatch_id < self.num_microbatch * self.num_model_chunks
+        ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
         microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
         model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
         if not is_forward:
+            # Reverse order
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id

-    def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
+    def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
            For interleaved 1F1B.
@@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input tensor or input tensor list.
+            Any: The wait handles for the communication.
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
-                input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+                input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
                 if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                     self.tensor_metadata_recv = create_send_metadata(input_tensor)

-                return input_tensor
+                return input_tensor, wait_handles
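+        # The first virtual stage has no predecessor; its input comes from the micro batch itself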
+        return None, []

-    def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
+    def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
         """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
            For interleaved 1F1B.
@@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Any: The input gradient tensor or gradient tensor list.
+            Any: The wait handles for the communication.
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
-                output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+                output_tensor_grad, wait_handles = self.comm.recv_backward(
+                    next_rank, metadata_recv=self.grad_metadata_recv
+                )
                 if self.enable_metadata_cache and self.grad_metadata_recv is None:
                     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+                return output_tensor_grad, wait_handles

-                return output_tensor_grad
+        return None, []

-    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
+    def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
         """Sends the input tensor to the next stage in pipeline.
            For interleaved 1F1B.
@@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
             model_chunk_id (int): The current model chunk idx.
             output_object (Any): Object to be sent.
             next_rank (int, optional): The rank of the recipient of the tensor.
+
+        Returns:
+            Any: The wait handles for the communication.
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
-                self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+                send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
                 self.send_tensor_metadata = not self.enable_metadata_cache
+                return send_handles
+        return []

-    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
+    def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
         """Sends the gradient tensor to the previous stage in pipeline.
            For interleaved 1F1B.
@@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
             model_chunk_id (int): The current model chunk idx.
             input_object (Any): Object to be sent.
             prev_rank (int, optional): The rank of the recipient of the tensor
+
+        Returns:
+            Any: The wait handles for the communication.
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
-                self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+                send_handles = self.comm.send_backward(
+                    input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
+                )
                 self.send_grad_metadata = not self.enable_metadata_cache
+                return send_handles
+        return []

-    def send_forward_recv_backward(
-        self,
-        model_chunk_id_send: int,
-        model_chunk_id_recv: int,
-        output_tensor: Any,
-        next_rank: Optional[int] = None,
-        send_prior_fallback: Optional[bool] = None,
-    ) -> Any:
+    def send_forward_recv_forward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
+    ) -> Tuple[Any, List]:
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
-            send_data = not self.stage_manager.is_last_stage()
+            is_send = not self.stage_manager.is_last_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
-            recv_data = not self.stage_manager.is_last_stage()
-
-        if send_data and recv_data:
-            if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
-                send_prior_fallback = None  # must not fallback
-            output_tensor_grad = self.comm.send_forward_recv_backward(
-                output_tensor,
-                next_rank,
-                send_metadata=self.send_tensor_metadata,
-                metadata_recv=self.grad_metadata_recv,
-                send_prior_fallback=send_prior_fallback,
-            )
-            self.send_tensor_metadata = not self.enable_metadata_cache
-            if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
-            return output_tensor_grad
+            is_recv = not self.stage_manager.is_first_stage()
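+        # Fuse this chunk's forward send with the next forward recv in one batched P2P call, so both can overlap with compute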
+        input_tensor, wait_handles = self.comm.send_forward_recv_forward(
+            output_tensor,
+            is_send,
+            is_recv,
+            send_metadata=self.send_tensor_metadata,
+            metadata_recv=self.tensor_metadata_recv,
+            send_first=send_first,
+        )
+        # Cache metadata
+        self.send_tensor_metadata = not self.enable_metadata_cache and is_send
+        if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
+            self.tensor_metadata_recv = create_send_metadata(input_tensor)
+        return input_tensor, wait_handles
-
-        # send only or recv only
-        self.send_forward(model_chunk_id_send, output_tensor)
-        return self.recv_backward(model_chunk_id_recv)

-    def send_backward_recv_forward(
-        self,
-        model_chunk_id_send: int,
-        model_chunk_id_recv: int,
-        input_tensor_grad: Any,
-        prev_rank: Optional[int] = None,
-        send_prior_fallback: Optional[bool] = None,
-    ) -> Any:
+    def send_backward_recv_backward(
+        self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
+    ) -> Tuple[Any, List]:
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
-            send_data = not self.stage_manager.is_first_stage()
+            is_send = not self.stage_manager.is_first_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
-            recv_data = not self.stage_manager.is_first_stage()
-
-        if send_data and recv_data:
-            if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
-                send_prior_fallback = None  # must not fallback
-            input_tensor = self.comm.send_backward_recv_forward(
-                input_tensor_grad,
-                prev_rank,
-                send_metadata=self.send_grad_metadata,
-                metadata_recv=self.tensor_metadata_recv,
-                send_prior_fallback=send_prior_fallback,
-            )
-            self.send_grad_metadata = not self.enable_metadata_cache
-            if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                self.tensor_metadata_recv = create_send_metadata(input_tensor)
-            return input_tensor
-
-        # send only or recv only
-        self.send_backward(model_chunk_id_send, input_tensor_grad)
-        return self.recv_forward(model_chunk_id_recv)
-
-    def send_forward_recv_forward(
-        self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
-    ):
-        if send_prior:
-            self.send_forward(model_chunk_id_send, output_tensor)
-            input_tensor = self.recv_forward(model_chunk_id_recv)
-        else:
-            input_tensor = self.recv_forward(model_chunk_id_recv)
-            self.send_forward(model_chunk_id_send, output_tensor)
-        return input_tensor
-
-    def send_backward_recv_backward(
-        self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
-    ):
-        if send_prior:
-            self.send_backward(model_chunk_id_send, input_tensor_grad)
-            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
-        else:
-            output_tensor_grad = self.recv_backward(model_chunk_id_recv)
-            self.send_backward(model_chunk_id_send, input_tensor_grad)
-        return output_tensor_grad
+            is_recv = not self.stage_manager.is_last_stage()
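+        # Fuse this chunk's grad send with the next grad recv in one batched P2P call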
+        output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
+            input_tensor_grad,
+            is_send,
+            is_recv,
+            send_metadata=self.send_grad_metadata,
+            metadata_recv=self.grad_metadata_recv,
+            send_first=send_first,
+        )
+        # Cache metadata
+        self.send_grad_metadata = not self.enable_metadata_cache and is_send
+        if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
+            self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+        return output_tensor_grad, wait_handles

     def forward_step(
         self,
@@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
         Returns:
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
+        # Load input ids, attention mask and labels
         micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
         # for the first stage, input_obj is None
-        # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
+        # for other stages, input_obj is the output of the previous stage containing hidden_states etc.
+        # Only attention_mask from micro_batch is used
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if isinstance(model_chunk, ModuleList):
@@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
             accum_loss = torch.scalar_tensor(0, device=get_current_device())

+        fwd_wait_handles = []
         model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
-        input_obj = self.recv_forward(model_chunk_id)
+        input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)

         for i in range(self.num_microbatch * self.num_model_chunks):
-            last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
+            last_batch = i == self.num_microbatch * self.num_model_chunks - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+
+            # Wait until current input is received
+            _wait_p2p(fwd_wait_handles)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

-            if not last_iteration:
-                input_obj = self.send_forward_recv_forward(
+            if not last_batch:
+                input_obj, fwd_wait_handles = self.send_forward_recv_forward(
                     model_chunk_id_send=model_chunk_id,
                     model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
                     output_tensor=output_obj,
-                    send_prior=self.stage_manager.stage % 2 == 0,
+                    send_first=self.stage_manager.stage % 2 == 0,
                 )
             else:
-                self.send_forward(model_chunk_id, output_obj)
+                fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)

         if outputs is not None:
             outputs = merge_batch(outputs)
@@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
         self.load_batch(data_iter)

         num_microbatch = self.num_microbatch * self.num_model_chunks
+        # Forward + until 1st backward
         num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+        # Steps needed to reach the last chunk
         num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
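+        # e.g. 4 stages, 2 chunks, stage 0: (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 warmup microbatches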
         num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
         num_microbatch_remaining = num_microbatch - num_warmup_microbatch
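+        # The remaining microbatches run one forward + one backward per step (1F1B); cooldown then drains the leftover backward passes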
@@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
         if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
             accum_loss = torch.scalar_tensor(0, device=get_current_device())

+        bwd_wait_handles = []
+        # Get the 1st input batch
         model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
-        input_obj = self.recv_forward(model_chunk_id)
+        input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)

         # Run warmup forward passes.
         for i in range(num_warmup_microbatch):
-            last_iteration = i == num_warmup_microbatch - 1
+            last_batch = i == num_warmup_microbatch - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            # Wait for input
+            _wait_p2p(fwd_wait_handles)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             input_objs[model_chunk_id].append(input_obj)
             output_objs[model_chunk_id].append(output_obj)

-            if last_iteration and num_microbatch_remaining == 0:
-                self.send_forward(model_chunk_id, output_obj)
+            if last_batch and num_microbatch_remaining == 0:
+                fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
             else:
-                input_obj = self.send_forward_recv_forward(
+                input_obj, fwd_wait_handles = self.send_forward_recv_forward(
                     model_chunk_id_send=model_chunk_id,
                     model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
                     output_tensor=output_obj,
-                    send_prior=self.stage_manager.stage % 2 == 0,
+                    send_first=self.stage_manager.stage % 2 == 0,
                 )

         if num_microbatch_remaining > 0:
             model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
-            output_obj_grad = self.recv_backward(model_chunk_id)
+            output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)

         # Run 1F1B in steady state.
         for i in range(num_microbatch_remaining):
-            last_iteration = i == num_microbatch_remaining - 1
-
-            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+            fwd_batch_id = i + num_warmup_microbatch
+            last_batch = i == num_microbatch_remaining - 1
+            model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
+
+            # Wait for input.
+            _wait_p2p(fwd_wait_handles)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             # Add input_obj and output_obj to end of list.
             input_objs[model_chunk_id].append(input_obj)
@@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
             # Pop output_obj and output_obj from the start of the list for the backward pass.
             _input_obj = input_objs[model_chunk_id].pop(0)
             _output_obj = output_objs[model_chunk_id].pop(0)
-            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

-            # NOTE: perform 2x communication for forward and backward
-            def send_forward_recv_backward():
-                if last_iteration and num_microbatch == num_microbatch_remaining:
-                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
-                    self.send_forward(model_chunk_id, output_obj)
+            # Helper functions
+            def send_forward_recv_forward():
+                if last_batch:
+                    model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
+                    wait_handles = self.send_forward(model_chunk_id, output_obj)
+                    return None, wait_handles
                 else:
-                    output_obj_grad = self.send_forward_recv_backward(
-                        model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
-                        model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+                    input_obj, wait_handles = self.send_forward_recv_forward(
+                        model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
+                        model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
                         output_tensor=output_obj,
-                        send_prior_fallback=self.stage_manager.stage % 2 == 0,
+                        send_first=self.stage_manager.stage % 2 == 0
+                        and i > 0,  # Receive from warmup stage first in the first batch
                     )
-                    return output_obj_grad
+                    return input_obj, wait_handles

-            def send_backward_recv_forward():
-                if last_iteration:
+            def send_backward_recv_backward():
+                no_cooldown = num_microbatch == num_microbatch_remaining
+                if last_batch and no_cooldown:
                     model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-                    self.send_backward(model_chunk_id, input_obj_grad)
+                    wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
+                    return None, wait_handles
                 else:
-                    input_obj = self.send_backward_recv_forward(
+                    output_obj_grad, wait_handles = self.send_backward_recv_backward(
                         model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
-                        model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
+                        model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
                         input_tensor_grad=input_obj_grad,
-                        send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
+                        send_first=self.stage_manager.stage % 2 == 0,
                     )
-                    return input_obj
+                    return output_obj_grad, wait_handles

-            if self.stage_manager.stage % 2 == 0:
-                output_obj_grad = send_forward_recv_backward()
-                input_obj = send_backward_recv_forward()
-            else:
-                input_obj = send_backward_recv_forward()
-                output_obj_grad = send_forward_recv_backward()
+            input_obj, fwd_wait_handles = send_forward_recv_forward()
+            # Wait for upstream grad
+            _wait_p2p(bwd_wait_handles)
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+            # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
+            # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
+            # however in practice this works fine, and Megatron does this too
+            # (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
+            # if deadlock, call _wait_p2p(fwd_wait_handles) here
+            output_obj_grad, bwd_wait_handles = send_backward_recv_backward()

         if num_microbatch_remaining == 0:
             model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
-            output_obj_grad = self.recv_backward(model_chunk_id)
+            output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)

         # Run cooldown backward passes.
         for i in range(num_microbatch_remaining, num_microbatch):
-            last_iteration = i == num_microbatch - 1
+            last_batch = i == num_microbatch - 1
             model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
             _input_obj = input_objs[model_chunk_id].pop(0)
             _output_obj = output_objs[model_chunk_id].pop(0)
-            # output_obj_grad = self.recv_backward(model_chunk_id)
-            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+            # Wait for upstream grad
+            _wait_p2p(bwd_wait_handles)
+            # backward local grads
+            input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

-            if not last_iteration:
-                output_obj_grad = self.send_backward_recv_backward(
+            if not last_batch:
+                output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
                     model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
                     model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
                     input_tensor_grad=input_obj_grad,
-                    send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
+                    send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
                 )
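+                # With p2p overlap, the fused call above must leave a pending grad recv to wait on next iteration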
+                assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
             else:
                 model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
-                self.send_backward(model_chunk_id, input_obj_grad)
+                _ = self.send_backward(model_chunk_id, input_obj_grad)

         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)