diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index daf81034f..2831b10a3 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel

+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint
             split will allocate memory for the output tensor if we don't apply shard on the first dimension of
@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass

+            return alloc_numel, peak_numel
+
         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel

         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint
@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel

+            return alloc_numel, peak_numel
+
         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel

         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)

         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
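
Context for the change above: Python integers are immutable and passed by value, so an analysis helper that only rebinds its alloc_numel/peak_numel parameters cannot update the caller's counters; the patch therefore makes each helper return the updated pair and has the loops reassign it. The following is a minimal standalone sketch of that returned-counter pattern (names such as dummy_gather_analysis are illustrative, not the actual ColossalAI helpers):

    # Minimal sketch, assuming a gather-like action that allocates an output buffer
    # and may free its input; illustrative only, not the ColossalAI implementation.
    def dummy_gather_analysis(output_numel: int, input_numel: int, discard_input: bool,
                              alloc_numel: int, peak_numel: int):
        # Rebinding the int parameters alone would be lost at the call site,
        # so the updated counters must be returned explicitly.
        peak_numel = max(peak_numel, alloc_numel + output_numel)
        alloc_numel += output_numel
        if discard_input:
            alloc_numel -= input_numel
        return alloc_numel, peak_numel

    alloc, peak = 0, 0
    for idx, (out_n, in_n) in enumerate([(8, 4), (16, 8)]):
        # mirror the patched loops: the first action keeps its input, later ones discard it
        alloc, peak = dummy_gather_analysis(out_n, in_n, discard_input=(idx != 0),
                                            alloc_numel=alloc, peak_numel=peak)
    print(alloc, peak)  # 16 24: the counters now accumulate across every action

Without the returned pair (as in the pre-patch helpers ending in bare mutation or pass), alloc and peak would still be 0 after the loop, which is exactly the bug the diff fixes.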