mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix runtime apply memory estimation (#2281)
* [autoparallel] align the data_ptr with the old version of the auto activation checkpoint pipeline
* [autoparallel] use fwd_time and bwd_time instead of fwd_flop and bwd_flop
* [autoparallel] specify comm nodes' memory cost in construct chain
* [autoparallel] fix wrong runtime apply calculation

pull/2288/head
parent 8e8900ff3f
commit 22e947f982
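The hunks below all make the same kind of change: the nested analysis helpers in ShapeConsistencyManager update the integer counters alloc_numel and peak_numel, but Python passes integers by value, so rebinding them inside a helper never reaches the caller. The fix returns the counters from every helper and reassigns them at the call sites. A minimal, self-contained sketch of that failure mode and the fix (function names here are illustrative, not ColossalAI APIs):

# Minimal sketch of the bug class fixed in this commit: rebinding an int
# parameter inside a helper does not change the caller's variable, so the
# memory counters must be returned and reassigned. Illustrative names only.

def broken_analysis(alloc_numel: int, peak_numel: int) -> None:
    alloc_numel += 100                          # rebinds a local name only
    peak_numel = max(peak_numel, alloc_numel)

def fixed_analysis(alloc_numel: int, peak_numel: int) -> tuple:
    alloc_numel += 100
    peak_numel = max(peak_numel, alloc_numel)
    return alloc_numel, peak_numel              # hand the updated counters back

alloc, peak = 0, 0
broken_analysis(alloc, peak)
assert (alloc, peak) == (0, 0)                  # caller never sees the update

alloc, peak = fixed_analysis(alloc, peak)
assert (alloc, peak) == (100, 100)              # update captured by reassignment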
@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel

+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint

             split will allocate memory for the output tensor if we don't apply shard on the first dimension of
@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass
+            return alloc_numel, peak_numel

         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
             pass
+            return alloc_numel, peak_numel

         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint
@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
+            return alloc_numel, peak_numel

         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
             pass
+            return alloc_numel, peak_numel

         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
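pattern_to_func_dict pairs each collective communication pattern with a [forward, backward] analysis function; the final hunk shows those pairs being walked in forward order and then in reverse for the backward pass. A rough sketch of that dispatch shape, using simplified stand-ins rather than the real CollectiveCommPattern/CommSpec machinery:

# Rough sketch of the [forward, backward] dispatch table.
# Pattern and the analysis functions are simplified stand-ins,
# not the real CollectiveCommPattern / CommSpec classes.
from enum import Enum, auto

class Pattern(Enum):
    GATHER_FWD_SPLIT_BWD = auto()

def gather_analysis(numel: int, alloc: int, peak: int):
    alloc += numel                          # all_gather materialises the full tensor
    return alloc, max(peak, alloc)

def split_analysis(numel: int, alloc: int, peak: int):
    return alloc, peak                      # treated as free in this toy version

pattern_to_func_dict = {
    Pattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
}

fwd_fn, bwd_fn = pattern_to_func_dict[Pattern.GATHER_FWD_SPLIT_BWD]
alloc, peak = fwd_fn(1024, 0, 0)            # forward walk uses the first entry
alloc, peak = bwd_fn(1024, alloc, peak)     # backward walk uses the second
print(alloc, peak)                          # 1024 1024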
@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)

         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
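After both loops, the counters are folded into MemoryCost: whatever is still allocated once the sequence finishes counts as activation memory, and the overshoot of the peak above that counts as temporary memory. A toy numeric check of that split (MemoryCost here is a stand-in dataclass, not the ColossalAI class):

# Toy check of the activation/temp split used above.
# MemoryCost is a stand-in dataclass, not colossalai's MemoryCost.
from dataclasses import dataclass

@dataclass
class MemoryCost:
    activation: int = 0
    temp: int = 0

fwd_alloc_numel, fwd_peak_numel = 0, 0

# Suppose one comm action briefly needs 8192 elements but only 4096 stay allocated.
fwd_peak_numel = max(fwd_peak_numel, fwd_alloc_numel + 8192)
fwd_alloc_numel += 4096

fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
assert fwd_mem == MemoryCost(activation=4096, temp=4096)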