[autoparallel] fix runtime apply memory estimation (#2281)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specifycomm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation
pull/2288/head
Boyuan Yao 2 years ago committed by GitHub
parent 8e8900ff3f
commit 22e947f982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze split memory footprint
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# kind of weird, and I think we could ignore it for now.
pass
return alloc_numel, peak_numel
def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
"""
pass
return alloc_numel, peak_numel
def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_to_all memory footprint
@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
"""
pass
return alloc_numel, peak_numel
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
if idx == 0:
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
else:
fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
fwd_peak_numel) if idx == 0 else fwd_action(
comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
bwd_action, comm_spec = action_spec_pair
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
bwd_peak_numel) if idx == 0 else bwd_action(
comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)

Loading…
Cancel
Save