@@ -1,7 +1,7 @@
 import threading
 from enum import Enum
 from typing import List, Any, Tuple, Dict
-from abc import ABC, abstractmethod
+from abc import ABC
 import torch
 from torch import nn
@@ -18,9 +18,8 @@ use_color_debug = False
 use_progress = False

 # TODO:
-# 1. design a unique_key without node.name (maybe a combination of microbatch_id and stage_id)
-# 2. use a waiting list to contain the incomplete WorkItem
-# 3. think about the representation of the order of args and kwargs
+# 1. replace world_size with other parameters
+# 2. adjust to args and kwargs


 def color_debug(text, prefix=' ', color='blue'):
@@ -126,33 +125,32 @@ class RemoteOptimizer:

 class Worker:

     def __init__(self,
-                 cur_rank_module: nn.Module,
-                 rank: int,
-                 world_size: int,
+                 module_partition: nn.Module,
+                 pp_rank: int,
+                 actual_stage_num: int,
                  num_microbatches: int,
                  max_outstanding: int,
                  device: str,
                  checkpoint: bool = False) -> None:
         super().__init__()
-        self.rank = rank
-        self.world_size = world_size
+        self.pp_rank = pp_rank
+        self.actual_stage_num = actual_stage_num
         self.num_microbatches = num_microbatches
         self.max_outstanding = max_outstanding
         self.outstanding = 0
         self.checkpoint = checkpoint
-        if device == 'cuda':
-            device = f'cuda:{rank}'
         self.device = device
         self.future_devices = None if device is None or device == 'cpu' else [device]
-        self.stage_to_worker_rref: Dict[int, PyRRef] = None
+        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None

         # module
-        self.cur_rank_module = cur_rank_module.to(device)
+        self.module_partition = module_partition.to(device)

         self.debug_list = [None] * num_microbatches
         self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
@@ -164,16 +162,16 @@ class Worker:
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
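+        # all real compute runs on one daemon thread per worker; the RPC-exposed
+        # methods only enqueue work items and wait on their result futures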
-        self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True)
+        self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
         self.main_loop_thread.start()

     def _get_future_by_device(self):
         return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])

-    def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None:
-        assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs"
-        assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
-        self.stage_to_worker_rref = stage_to_worker_rref
+    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
+        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
+        assert pp_rank_to_worker_rref is not None, "pp_rank_to_worker_rref must be a dict instead of None"
+        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref

     def get_output_by_key(self, key: UniqueKey) -> Any:
         with self.output_list_condition_lock:
@@ -183,7 +181,7 @@ class Worker:
             output_work_item = self.output_list[key]

         output = output_work_item.output.wait()
-        # color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red')
+        # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
         output_work_item.refcount += 1

         # all consumers have been satisfied, the work_item can be released
@@ -193,8 +191,13 @@ class Worker:
         return output

-    # just for first rank
-    # TODO : input is args kwargs
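+    # parameter accessors invoked over RPC by the engine
+    # (see remote_parameters / remote_grad in PipelineEngineBase below)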
+    def get_parameters(self) -> List[torch.Tensor]:
+        return [p for p in self.module_partition.parameters()]
+
+    def get_parameter_gradients(self) -> List[torch.Tensor]:
+        return [p.grad for p in self.module_partition.parameters()]
+
+    # just for first pp_rank
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
         with self.work_list_condition_lock:
             assert self.consumer_stage_ids is not None
@@ -203,16 +206,15 @@
             output = self._get_future_by_device()
             args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
-            work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
-                                 consumer_num)
+            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
+                                 self.num_microbatches, consumer_num)
             self.work_list[key] = work_item
-            color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta')
+            color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

-    # just for last rank
-    # TODO : write a function to add gradient to work_list and check for contradictions
+    # just for last pp_rank
     def _begin_backward(self, microbatch_id: int):
         with self.work_list_condition_lock:
             assert self.producer_stage_ids is not None
@@ -221,10 +223,10 @@
             output = self._get_future_by_device()
             grad_wrt_loss = torch.tensor(1, device=self.device)

-            work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
+            work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                  self.num_microbatches, producer_num)

-            color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta')
+            color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')

             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()
@@ -238,7 +240,7 @@ class Worker:
             consumer_num = len(self.consumer_stage_ids)
             assert producer_num > 0, "only stage that has producers can subscribe producers"

-        stage_id = self.rank
+        stage_id = self.pp_rank

         subscribe_forward_futures: List[Future] = [None] * producer_num
         output = self._get_future_by_device()
@@ -246,10 +248,10 @@
         for i in range(producer_num):
             producer_stage_id = self.producer_stage_ids[i]
             producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
-            producer_worker_rref = self.stage_to_worker_rref[producer_stage_id]
+            producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
             subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)

-        color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
+        color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
                     'magenta')

         args = []
@@ -261,14 +263,14 @@
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
                                            self.num_microbatches, consumer_num)

-        color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')

         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_producer
             color_debug(
-                f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
+                f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
                 'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()
@@ -282,18 +284,18 @@
             assert consumer_num > 0, "only stage that has consumers can subscribe consumers"

         # TODO : is this right?
-        stage_id = self.rank
+        stage_id = self.pp_rank

         subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()

-        color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch',
-                    'magenta')
+        color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
+                    'data dispatch', 'magenta')

         for i in range(consumer_num):
             consumer_stage_id = self.consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
-            consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id]
+            consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
             subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)

         args = []
@@ -305,7 +307,7 @@
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
                                            self.num_microbatches, producer_num)

-        color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')

         # add work_item to work_list
         with self.work_list_condition_lock:
@@ -313,13 +315,12 @@
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
             color_debug(
-                f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
+                f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
                 'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

     # TODO : fit in any type of partition of network
     def _get_producer_consumer(self) -> None:
-        rank = self.rank
+        rank = self.pp_rank
         assert self.producer_stage_ids is None, f"all the producers of rank {rank} have been subscribed"
         assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} have been subscribed"
@@ -332,34 +333,41 @@
         next_rank = rank + 1
         if prev_rank >= 0:
             self.producer_stage_ids.append(prev_rank)
-        if next_rank <= self.world_size - 1:
+        if next_rank <= self.actual_stage_num - 1:
             self.consumer_stage_ids.append(next_rank)

-    def _skip_forward(self, work_item_phase: Phase) -> bool:
-        if work_item_phase == Phase.FORWARD and \
-                self.max_outstanding is not None and \
-                self.outstanding >= self.max_outstanding:
-            return True
-        return False
-
     def _get_work_item_key(self) -> UniqueKey:
         with self.work_list_condition_lock:
             while len(self.work_list) == 0:
                 self.work_list_condition_lock.wait()

-            # execute backward first (if backward phase in work_list)
             select_work_list_key = None
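+            # selection policy: skip FORWARD items once `outstanding` reaches
+            # `max_outstanding`; otherwise prefer FORWARD over BACKWARD and, within
+            # the same phase, the smallest microbatch_id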
             for key in self.work_list:
                 work_item = self.work_list[key]
-                if work_item.phase == Phase.BACKWARD:
-                    return key
-                if self._skip_forward(work_item.phase):
+                if work_item.phase == Phase.FORWARD and \
+                        self.max_outstanding is not None and \
+                        self.outstanding >= self.max_outstanding:
                     continue
-                else:
-                    select_work_list_key = key
-                if select_work_list_key is not None and \
-                        select_work_list_key.phase == Phase.FORWARD and \
-                        key.phase == Phase.BACKWARD:
-                    continue
+                if select_work_list_key is None:
+                    select_work_list_key = key
+                else:
+                    phase_pair = (select_work_list_key.phase, key.phase)
+                    # choose forward first
+                    if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
+                        select_work_list_key = key
+                    elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
+                        continue
+                    # choose the work_item with a smaller microbatch_id first
+                    elif key.microbatch_id < select_work_list_key.microbatch_id:
+                        select_work_list_key = key
             return select_work_list_key
     def _consume_work_item_by_phase(self, work_item: WorkItem):
@@ -369,7 +377,10 @@
         microbatch_id = work_item.microbatch_id
         consume_result = None

-        # color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue')
+        # if self.pp_rank == 0:
+        #     print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
+        # color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')

         if phase == Phase.FORWARD:
             self.outstanding += 1
@@ -381,19 +392,20 @@
                     args[i] = arg_obj.requires_grad_()

             # TODO : use process manager to acquire rank info later
-            is_last_stage = len(self.consumer_stage_ids) == 0
+            is_last_stage = (self.pp_rank == self.actual_stage_num - 1)

            # the last stage doesn't need checkpointing, since it runs backward immediately
             if self.checkpoint and not is_last_stage:
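+                # checkpoint mode: run forward under no_grad and cache only the inputs;
+                # the autograd graph is rebuilt by re-running forward during BACKWARD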
                 with torch.no_grad():
-                    consume_result = self.cur_rank_module(*args, **kwargs)
+                    consume_result = self.module_partition(*args, **kwargs)

                 stage_outputs = None
                 stage_inputs = args
                 self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
                                                                                     stage_outputs,
                                                                                     checkpoint=True)
             else:
-                consume_result = self.cur_rank_module(*args, **kwargs)
+                # TODO : replace with *args, **kwargs and ensure the consume_result is a tuple
+                consume_result = self.module_partition(*args, **kwargs)
                 stage_outputs = consume_result
                 stage_inputs = args
                 self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
@@ -415,15 +427,13 @@
             stage_inputs = backward_cache.stage_inputs
             grad_tensors = args

-            # color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow')
-            if self.checkpoint:
-                stage_outputs = [self.cur_rank_module(*stage_inputs)]
+            use_checkpoint = backward_cache.checkpoint
+            if use_checkpoint:
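+                # a checkpointed stage cached no outputs: re-run forward (with grad
+                # tracking this time) to rebuild the autograd graph before backward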
+                stage_outputs = [self.module_partition(*stage_inputs)]

             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
-            # color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow')

             # collect grad of input tensor
             consume_result = []
             for input_node in stage_inputs:
@@ -453,7 +463,7 @@
                 work_item = self.work_list.pop(work_item_key)

             color_debug(
-                f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
+                f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
                 'work loop', 'green')

             with self.output_list_condition_lock:
@@ -464,11 +474,8 @@
             consume_result = self._consume_work_item_by_phase(work_item)

             color_debug(
-                f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
+                f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
                 'work loop', 'green')

-            # if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD:
-            #     from time import sleep
-            #     sleep(5)
-
             work_item.output.set_result(consume_result)
@@ -479,11 +486,11 @@ class PipelineEngineBase(ABC, nn.Module):

     def __init__(self,
                  module_partitions,
-                 chunk,
-                 world_size,
+                 stage_num,
                  num_microbatches,
                  device: str,
                  max_outstanding=None,
+                 chunk: int = 1,
                  use_interleave: bool = False,
                  checkpoint: bool = False) -> None:
         super().__init__()
@@ -492,55 +499,86 @@ class PipelineEngineBase(ABC, nn.Module):
         self.num_microbatches = num_microbatches
         self.device = device
         self.max_outstanding = max_outstanding
-        self.world_size = world_size
+        self.stage_num = stage_num
         self.checkpoint = checkpoint
         self.use_interleave = use_interleave

-        self.stage_to_worker_rref: Dict[int, PyRRef] = dict()
+        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()

+        self._check_argument()
+        self._create_pp_rank_to_rpc_worker_id()
         self._init_worker()

+    def _check_argument(self):
+        self.virtual_stage_num = self.stage_num * self.chunk
+        assert self.stage_num <= torch.cuda.device_count(), "stage_num must not exceed the device count!"
+        assert self.virtual_stage_num == len(
+            self.module_partitions), "stage_num * chunk must be equal to the length of module_partitions!"
+        if self.use_interleave:
+            assert self.num_microbatches % self.stage_num == 0, "if you use the interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+
+    def _get_actual_stage_num(self):
+        return self.stage_num if self.chunk == 1 else self.virtual_stage_num
+
+    def _create_pp_rank_to_rpc_worker_id(self):
+        """Create a map from virtual stage (pp_rank) to rpc worker id, which is useful when use_interleave is True.
+        e.g. if a model is split into 4 parts (len(self.module_partitions) == 4), stage_num is 2 and chunk is 2,
+        then pp_rank_to_rpc_worker_id == [0, 1, 0, 1]: the first and third partitions are placed on device 0
+        and the others on device 1.
+        """
+        stage_num = self.stage_num
+        actual_stage_num = self._get_actual_stage_num()
+        self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num
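+        # round-robin assignment: virtual stage i is served by rpc worker i % stage_num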
+        for pp_rank in range(actual_stage_num):
+            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num

     def _init_worker(self):
-        world_size = self.world_size
+        actual_stage_num = self._get_actual_stage_num()
+
         max_outstanding = self.max_outstanding
         checkpoint = self.checkpoint
         num_microbatches = self.num_microbatches
         device = self.device

-        # TODO : world size is correct ?
-        for rank in range(world_size):
-            cur_rank_module = self.module_partitions[rank]
-            self.stage_to_worker_rref[rank] = rpc.remote(rank,
-                                                         Worker,
-                                                         args=(cur_rank_module, rank, world_size, num_microbatches,
-                                                               max_outstanding, device, checkpoint))
+        for pp_rank in range(actual_stage_num):
+            module_partition = self.module_partitions[pp_rank]
+            rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
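+            # assumes rpc worker i drives cuda:i, so the partition is co-located with its worker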
+            if device[:4] == 'cuda':
+                device = f'cuda:{rpc_worker_id}'
+            self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
+                                                              Worker,
+                                                              args=(module_partition, pp_rank, actual_stage_num,
+                                                                    num_microbatches, max_outstanding, device,
+                                                                    checkpoint))

         # let each worker know global worker rref (include itself)
-        for rank in range(world_size):
-            self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref)
-
-    @abstractmethod
-    def forward_backward(self):
-        pass
+        for pp_rank in range(actual_stage_num):
+            self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
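+
+    # pull a copy of each stage's parameters / gradients over RPC
+    # (e.g. for a master-side optimizer such as RemoteOptimizer, or for tests)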
+    def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
+        parameters = {}
+        for stage_id in self.pp_rank_to_worker_rref:
+            parameters[stage_id] = []
+            worker_rref = self.pp_rank_to_worker_rref[stage_id]
+            for p in worker_rref.rpc_sync().get_parameters():
+                parameters[stage_id].append(p)
+        return parameters
+
+    def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
+        grads = {}
+        for stage_id in self.pp_rank_to_worker_rref:
+            grads[stage_id] = []
+            worker_rref = self.pp_rank_to_worker_rref[stage_id]
+            for grad in worker_rref.rpc_sync().get_parameter_gradients():
+                grads[stage_id].append(grad)
+        return grads
-class FillDrainPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 module_partitions,
-                 chunk,
-                 world_size,
-                 num_microbatches,
-                 device: str,
-                 max_outstanding=None,
-                 use_interleave: bool = False,
-                 checkpoint: bool = False) -> None:
-        super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
-                         use_interleave, checkpoint)
-
+    # TODO : adjust to args and kwargs
     def forward_backward(self, batch: torch.Tensor):
-        first_stage_worker = self.stage_to_worker_rref[0]
+        first_stage_worker = self.pp_rank_to_worker_rref[0]
         microbatch_size = len(batch) // self.num_microbatches
+        actual_stage_num = self._get_actual_stage_num()

         microbatch_iter = range(self.num_microbatches)
         if use_progress:
@@ -550,31 +588,63 @@ class FillDrainPipelineEngine(PipelineEngineBase):
             microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]

             # forward subscribe asynchronously
-            for rank in range(1, self.world_size, 1):
-                worker_rref = self.stage_to_worker_rref[rank]
+            for pp_rank in range(1, actual_stage_num, 1):
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                 worker_rref.rpc_async().subscribe_producer(microbatch_id)

             # backward subscribe asynchronously
-            for rank in range(self.world_size - 2, -1, -1):
-                worker_rref = self.stage_to_worker_rref[rank]
+            for pp_rank in range(actual_stage_num - 2, -1, -1):
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                 worker_rref.rpc_async().subscribe_consumer(microbatch_id)

             # run one microbatch
             first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)

         # wait forward
-        # TODO : all the node to output
+        forward_result = None
+        last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
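+        # drain: collect each microbatch's outputs from the last stage;
+        # forward_result[i] gathers the i-th output across all microbatches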
+        for microbatch_id in range(self.num_microbatches):
+            key = UniqueKey(microbatch_id, Phase.FORWARD)
+            ret = last_worker_rref.rpc_sync().get_output_by_key(key)
+            if forward_result is None:
+                # note: use independent lists ([[]] * n would alias one list n times)
+                forward_result = [[] for _ in range(len(ret))]
+            for i in range(len(forward_result)):
+                forward_result[i].append(ret[i])

         # wait for last backward in rank0
         key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
         first_stage_worker.rpc_sync().get_output_by_key(key)
+        return forward_result

-class OneFOneBPipelineEngine(FillDrainPipelineEngine):
+class FillDrainPipelineEngine(PipelineEngineBase):

     def __init__(self,
-                 module_partitions,
-                 chunk,
-                 world_size,
-                 num_microbatches,
+                 module_partitions: List[nn.Module],
+                 stage_num: int,
+                 num_microbatches: int,
                  device: str,
+                 chunk: int = 1,
+                 use_interleave: bool = False,
+                 checkpoint: bool = False) -> None:
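+        # fill-drain (GPipe-style) schedule: no cap on in-flight forward microbatches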
+        max_outstanding = None
+        super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
+                         checkpoint)
+
+
+class OneFOneBPipelineEngine(PipelineEngineBase):
+
+    def __init__(self,
+                 module_partitions: List[nn.Module],
+                 stage_num: int,
+                 num_microbatches: int,
+                 device: str,
                  max_outstanding=None,
+                 chunk: int = 1,
                  use_interleave: bool = False,
                  checkpoint: bool = False) -> None:
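+        # 1F1B: by default, cap in-flight forward microbatches at the number of stages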
         if max_outstanding is None:
-            max_outstanding = world_size
-        super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
-                         use_interleave, checkpoint)
+            max_outstanding = len(module_partitions)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
+                         checkpoint)