mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix runtime apply memory estimation (#2281)
* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline
* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop
* [autoparallel] specify comm nodes' memory cost in construct chain
* [autoparallel] fix wrong runtime apply calculation
parent 8e8900ff3f
commit 22e947f982
@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint
             split will allocate memory for the output tensor if we don't apply shard on the first dimension of
|
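Why this added return matters: each analysis helper threads a pair of running counters, alloc_numel (memory currently held) and peak_numel (the high-water mark), through the communication sequence. A minimal sketch of that contract, with made-up numel parameters standing in for the values the real code derives from comm_spec:

    def gather_analysis_sketch(output_numel: int, input_numel: int, discard_input: bool,
                               alloc_numel: int, peak_numel: int):
        # gather materializes the output tensor, so the running allocation grows
        alloc_numel += output_numel
        # peak memory is the high-water mark of the running allocation
        peak_numel = max(peak_numel, alloc_numel)
        # if the input tensor is discarded after the op, its numel is freed
        if discard_input:
            alloc_numel -= input_numel
        # Python integers are immutable, so rebinding the parameters above is
        # invisible to the caller; the counters must be returned explicitly
        return alloc_numel, peak_numel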
@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass
 
+            return alloc_numel, peak_numel
+
         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint
|
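Replacing pass with return alloc_numel, peak_numel is not cosmetic: a body of pass returns None, and the caller (see the last hunk) now unpacks two values from every analysis function, so even a dummy analysis has to echo its inputs back unchanged. A toy reproduction of the failure mode:

    def old_reduce_analysis(alloc_numel, peak_numel):
        pass                                  # implicitly returns None

    def new_reduce_analysis(alloc_numel, peak_numel):
        return alloc_numel, peak_numel        # pass the counters through unchanged

    try:
        a, p = old_reduce_analysis(8, 8)
    except TypeError as err:
        print(err)                            # cannot unpack non-iterable NoneType object
    a, p = new_reduce_analysis(8, 8)          # fine: a == 8, p == 8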
@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
|
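pattern_to_func_dict maps each collective pattern to a [forward_analysis, backward_analysis] pair, so the memory-cost routine can pick the right helpers for every comm spec in the sequence. A self-contained sketch of that dispatch; apart from the GATHER_FWD_SPLIT_BWD entry visible in the diff, the string keys and stand-in bodies below are illustrative substitutes for the CollectiveCommPattern enum and the real analyses:

    def gather_analysis(spec, discard_input, alloc_numel, peak_numel):
        return alloc_numel, peak_numel        # stand-in body

    def split_analysis(spec, discard_input, alloc_numel, peak_numel):
        return alloc_numel, peak_numel        # stand-in body

    pattern_to_func_dict = {
        "GATHER_FWD_SPLIT_BWD": [gather_analysis, split_analysis],
    }

    # hypothetical comm specs, reduced to just their pattern name
    comm_action_sequence = ["GATHER_FWD_SPLIT_BWD", "GATHER_FWD_SPLIT_BWD"]
    fwd_actions = [pattern_to_func_dict[p][0] for p in comm_action_sequence]
    bwd_actions = [pattern_to_func_dict[p][1] for p in comm_action_sequence]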
@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
 
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
|
|
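This last hunk is the actual bug fix named in the commit title. The old loops called each analysis function and dropped its result, so fwd_alloc_numel/fwd_peak_numel (and their backward counterparts) stayed at their initial values and every runtime apply looked memory-free. The fix threads the returned counters from one action into the next, and also gives the first backward action discard_input=False, mirroring the forward side. A condensed, runnable sketch of the fixed accumulation; the two-tuple comm specs and the numbers are invented for illustration:

    def alloc_then_free(spec, discard_input, alloc_numel, peak_numel):
        output_numel, input_numel = spec      # toy spec: (output numel, input numel)
        alloc_numel += output_numel           # allocate the output tensor
        peak_numel = max(peak_numel, alloc_numel)
        if discard_input:
            alloc_numel -= input_numel        # input freed once the op is done
        return alloc_numel, peak_numel

    fwd_actions = [alloc_then_free, alloc_then_free]
    comm_action_sequence = [(64, 32), (32, 64)]

    fwd_alloc_numel, fwd_peak_numel = 0, 0
    for idx, (fwd_action, comm_spec) in enumerate(zip(fwd_actions, comm_action_sequence)):
        # only the first action keeps its input alive (discard_input=False)
        fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, idx != 0,
                                                     fwd_alloc_numel, fwd_peak_numel)

    activation = fwd_alloc_numel              # memory still held after the sequence
    temp = fwd_peak_numel - fwd_alloc_numel   # transient excess over the final allocation
    print(activation, temp)                   # 32 64 with the toy numbers above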