@@ -16,7 +16,6 @@ from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future


class Phase(Enum):
    FORWARD = 0
    BACKWARD = 1
@@ -136,9 +135,6 @@ class WorkerBase(ABC):
        self.criterion = criterion
        self.metric = metric

        # middleware info
        self._is_output = False

        # context to maintain loop
        self._initialize_context_container()
@@ -190,21 +186,33 @@ class WorkerBase(ABC):
        with self.output_list_condition_lock:
            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
            output_work_item = self.output_list[key]
            self.output_list.pop(key)

        output_work_item.refcount += 1
        refcount = output_work_item.refcount
        output = output_work_item.output
        if output_work_item.phase != Phase.INPUT:
            # lifecycle management for DAG scheduler
            lifecycle = len(self.get_consumer_stage_ids())
            if self.is_model_output():    # an extra reference for the scheduler to collect results
                lifecycle += 1
            with self.output_list_condition_lock:
                # not all consumers are satisfied yet, so put the work_item back
                # into the output list; once refcount reaches lifecycle it stays released
                if refcount < lifecycle:
                    self.output_list[key] = output_work_item
                    self.output_list_condition_lock.notify_all()
        else:
            with self.output_list_condition_lock:
                self.output_list[key] = output_work_item
                self.output_list_condition_lock.notify_all()

        if isinstance(output, Future):
            output = output.wait()

        # output_work_item.refcount += 1
        # TODO(jiangziyue) redesign lifecycle management for DAG scheduler
        # all consumers have been satisfied, the work_item can be released
        with self.output_list_condition_lock:
            if output_work_item.refcount >= len(self.consumer_stage_ids):
                self.output_list.pop(key)

        return output

    def get_parameters(self) -> List[torch.Tensor]:
        return [p for p in self.module_partition.parameters()]
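# A minimal sketch (hypothetical names, not the WorkerBase API) of the
# refcount-vs-lifecycle rule applied in get_output_by_key above: an entry
# stays in the output list until every consumer stage has fetched it, plus
# one extra scheduler fetch when the stage produces the model output.
import threading

class _Entry:
    def __init__(self, output, num_consumers, is_model_output):
        self.output = output
        self.refcount = 0
        # total fetches expected before the entry can be dropped
        self.lifecycle = num_consumers + (1 if is_model_output else 0)

class _OutputList:
    def __init__(self):
        self._entries = {}
        self._cond = threading.Condition()

    def put(self, key, entry):
        with self._cond:
            self._entries[key] = entry
            self._cond.notify_all()

    def get(self, key):
        with self._cond:
            self._cond.wait_for(lambda: key in self._entries)
            entry = self._entries.pop(key)
        entry.refcount += 1
        if entry.refcount < entry.lifecycle:
            # consumers still outstanding: put the entry back
            self.put(key, entry)
        return entry.output

out = _OutputList()
out.put("mb0", _Entry("activation", num_consumers=2, is_model_output=False))
assert out.get("mb0") == "activation"    # first consumer: entry re-inserted
assert out.get("mb0") == "activation"    # second consumer: entry released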
@@ -246,8 +254,6 @@ class WorkerBase(ABC):
            raise TypeError(f"Input batch can only be dict, list, tuple or tensor, but received {type(microbatch)}")

    # just for the first pp_rank
    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
    # TODO(jiangziyue) Define a Class for DAG.
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        output = self._get_future_by_device()
@@ -312,8 +318,7 @@ class WorkerBase(ABC):
            self.work_list[key] = work_item
            self.work_list_condition_lock.notify_all()

    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
    def subscribe_producer(self, microbatch_id: int, forward_only: bool):
    def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
        """
        You should call this function asynchronously.
        """
@@ -328,10 +333,6 @@ class WorkerBase(ABC):
                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
        else:
            with self.work_list_condition_lock:
                key = UniqueKey(microbatch_id, Phase.FORWARD)
                if key in self.work_list:
                    return
            producer_stage_ids = self.get_producer_stage_ids()
            producer_num = len(producer_stage_ids)
            if self.need_model_input():
@@ -361,10 +362,18 @@ class WorkerBase(ABC):
        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, forward_only)

        # add work_item to work_list
        return work_item_from_producer

    # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
    def subscribe_producer(self, microbatch_id: int, forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        with self.work_list_condition_lock:
            key = UniqueKey(microbatch_id, Phase.FORWARD)
            if key not in self.work_list:
                # In the current PP middleware design for DAG, get_output_by_key used by _subscribe_producer
                # can only be executed once for every producer-consumer stage pair, which is necessary
                # to count the lifecycle of the work_item. So keeping _subscribe_producer inside the same
                # lock as the work_list queue operations guarantees the consistency of the lifecycle counter.
                work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only)
                self.work_list[key] = work_item_from_producer
                self.work_list_condition_lock.notify_all()
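# A minimal sketch (hypothetical names) of why the existence check and the
# subscription above must share one lock: two RPC threads may race on
# subscribe_producer() for the same microbatch, and every get_output_by_key()
# call bumps a producer-side refcount, so a duplicate subscription would
# corrupt the lifecycle counter.
import threading

work_list = {}
work_list_condition = threading.Condition()

def subscribe_once(key, do_subscribe):
    with work_list_condition:
        if key in work_list:
            return    # already subscribed; producer refcounts stay correct
        # check-then-subscribe is atomic under the lock, so do_subscribe()
        # (and its refcount side effects) runs at most once per key
        work_list[key] = do_subscribe()
        work_list_condition.notify_all()

subscribe_once("mb0-forward", lambda: "work item")
subscribe_once("mb0-forward", lambda: 1 / 0)    # no-op: never re-subscribed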
@@ -444,12 +453,10 @@ class WorkerBase(ABC):
        self.producer_stage_ids = self.get_producer_stage_ids()
        self.consumer_stage_ids = self.get_consumer_stage_ids()

    # TODO(jiangziyue) Define a Class for DAG.
    def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
        partition_ids = topo.get_mid_partition_ids()
        return partition_ids[pp_rank]

    # TODO(jiangziyue) Define a Class for DAG.
    def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
        partition_ids = topo.get_mid_partition_ids()
        for i, id in enumerate(partition_ids):
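# The loop above is truncated by the hunk; the two helpers are inverse
# lookups over the same rank-ordered list of partition ids. A self-contained
# sketch of the round trip, assuming get_mid_partition_ids() returns ids
# ordered by pipeline rank:
def _pp_rank_to_partition_id(pp_rank, partition_ids):
    return partition_ids[pp_rank]

def _partition_id_to_pp_rank(partition_id, partition_ids):
    for pp_rank, pid in enumerate(partition_ids):
        if pid == partition_id:
            return pp_rank
    raise ValueError(f"unknown partition {partition_id}")

_ids = [3, 7, 9]    # hypothetical mid-partition ids, indexed by pp_rank
assert _partition_id_to_pp_rank(_pp_rank_to_partition_id(1, _ids), _ids) == 1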
@@ -552,6 +559,9 @@ class WorkerBase(ABC):
                need_input = True
        return not self.is_first_stage() and need_input

    def is_model_output(self):
        return self.is_last_stage()

    def _default_data_process_func(self, args_kwargs):
        if self.is_first_stage():
            args = args_kwargs[0]
@@ -748,7 +758,8 @@ class WorkerBase(ABC):
        # move the current work item to output_list to activate subscription in advance
        with self.work_list_condition_lock:
            work_item = self.work_list.pop(work_item_key)
            # self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
            work_item = self.work_list[work_item_key]

        with self.output_list_condition_lock:
            # assert work_item_key not in self.output_list
@@ -758,6 +769,8 @@ class WorkerBase(ABC):
        consume_result = self._consume_work_item_by_phase(work_item)

        work_item.output.set_result(consume_result)
        with self.work_list_condition_lock:
            self.work_list.pop(work_item_key)

        # if this is the last step in one batch, reset the context and do step
        if self._is_last_step(work_item):
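# A minimal sketch (hypothetical names) of the ordering the two hunks above
# enforce: the work item now stays in work_list while it is consumed and is
# popped only after its Future is fulfilled, so a late subscriber always
# finds either the pending key or the finished result, never neither.
import threading
from torch.futures import Future

work_list = {"mb0-forward": Future()}
work_list_condition = threading.Condition()

def finish(key, result):
    work_list[key].set_result(result)    # 1. publish the result first
    with work_list_condition:
        work_list.pop(key)               # 2. only then drop the queue entry

finish("mb0-forward", 42)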