[autoparallel] fix runtime apply memory estimation (#2281)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] use fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specify comm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

Boyuan Yao, committed by GitHub 2 years ago · commit 22e947f982 · parent 8e8900ff3f

@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint
             split will allocate memory for the output tensor if we don't apply shard on the first dimension of
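
Why the added return matters: `alloc_numel` and `peak_numel` are plain Python ints, so updates made inside an analysis helper are invisible to the caller unless the helper returns the pair and the caller rebinds it. A minimal self-contained sketch of the pattern (the simplified signature and sample numbers are illustrative, not the repository's code):

    def gather_analysis_sketch(input_numel, output_numel, discard_input, alloc_numel, peak_numel):
        # Gathering materializes the full output tensor, so allocation grows by its size.
        alloc_numel += output_numel
        # The peak is reached while both the input and output buffers are alive.
        peak_numel = max(peak_numel, alloc_numel)
        if discard_input:
            # The input buffer can be freed once the output exists.
            alloc_numel -= input_numel
        # Ints are immutable: without this return, the caller's counters never change.
        return alloc_numel, peak_numel

    # The caller must rebind, exactly as the fixed loops further down now do:
    alloc, peak = 0, 0
    alloc, peak = gather_analysis_sketch(64, 256, discard_input=False,
                                         alloc_numel=alloc, peak_numel=peak)
    # alloc == 256, peak == 256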
@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass
 
+            return alloc_numel, peak_numel
+
         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint
@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
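
For context, `pattern_to_func_dict` is a dispatch table: each collective pattern maps to a `[forward_analysis, backward_analysis]` pair, so the backward pass can reuse the same helpers by walking the communication sequence in reverse. A self-contained sketch of that shape (the enum members and stub helpers are stand-ins, not the real definitions):

    from enum import Enum, auto

    class Pattern(Enum):    # stand-in for CollectiveCommPattern
        GATHER_FWD_SPLIT_BWD = auto()
        ALL2ALL_FWD_ALL2ALL_BWD = auto()

    def gather_analysis(spec, discard_input, alloc, peak):
        return alloc, peak    # stub

    def split_analysis(spec, discard_input, alloc, peak):
        return alloc, peak    # stub

    def all2all_analysis(spec, discard_input, alloc, peak):
        return alloc, peak    # stub

    # gather's backward is a split, so its pair is asymmetric;
    # all-to-all is its own inverse, so both slots share one helper.
    pattern_to_func_dict = {
        Pattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
        Pattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
    }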
@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
 
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
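
This last hunk is the actual fix: the loops now capture the `(alloc, peak)` pair each action returns (previously the results were discarded, leaving both counters at zero), and the first backward action, like the first forward one, keeps rather than discards its input. A compact, runnable sketch of the corrected control flow (`run_analysis` and `MemoryCostSketch` are illustrative names, not the repository's API):

    from collections import namedtuple

    MemoryCostSketch = namedtuple("MemoryCostSketch", ["activation", "temp"])

    def run_analysis(actions, specs):
        alloc_numel, peak_numel = 0, 0
        for idx, (action, spec) in enumerate(zip(actions, specs)):
            # Only the first comm action in the sequence keeps (does not discard)
            # its input; the rebinding here is the essence of the fix.
            alloc_numel, peak_numel = action(spec, idx != 0, alloc_numel, peak_numel)
        return alloc_numel, peak_numel

    # With fwd_actions / bwd_actions / comm_action_sequence as in the method above:
    # fwd_alloc, fwd_peak = run_analysis(fwd_actions, comm_action_sequence)
    # bwd_alloc, bwd_peak = run_analysis(list(reversed(bwd_actions)),
    #                                    list(reversed(comm_action_sequence)))
    # Activation memory is what remains allocated afterwards; temp is the
    # transient overshoot above that steady state at the peak.
    # fwd_mem = MemoryCostSketch(activation=fwd_alloc, temp=fwd_peak - fwd_alloc)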
