@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input :
alloc_numel - = input_numel
return alloc_numel , peak_numel
def split_analysis ( comm_spec : CommSpec , discard_input : bool , alloc_numel : int , peak_numel : int ) :
""" analyze split memory footprint
split will allocate memory for the output tensor if we don ' t apply shard on the first dimension of
@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# kind of weird, and I think we could ignore it for now.
pass
return alloc_numel , peak_numel
def reduce_analysis ( comm_spec : CommSpec , discard_input : bool , alloc_numel : int , peak_numel : int ) :
"""
a dummy function for reduce memory footprint analysis , as the reduce action doesn ' t allocate extra memory
"""
pass
return alloc_numel , peak_numel
def all2all_analysis ( comm_spec : CommSpec , discard_input : bool , alloc_numel : int , peak_numel : int ) :
""" analyze all_to_all memory footprint
@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input :
alloc_numel - = input_numel
return alloc_numel , peak_numel
def identity_analysis ( comm_spec : CommSpec , discard_input : bool , alloc_numel : int , peak_numel : int ) :
"""
a dummy function for identity memory footprint analysis , as the identity action doesn ' t allocate extra memory
"""
pass
return alloc_numel , peak_numel
pattern_to_func_dict = {
CollectiveCommPattern . GATHER_FWD_SPLIT_BWD : [ gather_analysis , split_analysis ] ,
@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
for idx , action_spec_pair in enumerate ( zip ( fwd_actions , comm_action_sequence ) ) :
# the first forward comm action will not discard input
fwd_action , comm_spec = action_spec_pair
if idx == 0 :
fwd_action ( comm_spec , False , fwd_alloc_numel , fwd_peak_numel )
else :
fwd_action ( comm_spec , True , fwd_alloc_numel , fwd_peak_numel )
fwd_alloc_numel , fwd_peak_numel = fwd_action ( comm_spec , False , fwd_alloc_numel ,
fwd_peak_numel ) if idx == 0 else fwd_action (
comm_spec , True , fwd_alloc_numel , fwd_peak_numel )
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx , action_spec_pair in enumerate ( zip ( reversed ( bwd_actions ) , reversed ( comm_action_sequence ) ) ) :
bwd_action , comm_spec = action_spec_pair
bwd_action ( comm_spec , True , bwd_alloc_numel , bwd_peak_numel )
bwd_alloc_numel , bwd_peak_numel = bwd_action ( comm_spec , False , bwd_alloc_numel ,
bwd_peak_numel ) if idx == 0 else bwd_action (
comm_spec , True , bwd_alloc_numel , bwd_peak_numel )
fwd_mem = MemoryCost ( activation = fwd_alloc_numel , temp = fwd_peak_numel - fwd_alloc_numel )
bwd_mem = MemoryCost ( activation = bwd_alloc_numel , temp = bwd_peak_numel - bwd_alloc_numel )