diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index 4929e09ad..6af927272 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.tensor.shape_consistency import CommSpec
+from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .constants import (
diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index d566e3515..144712fc5 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -3,8 +3,10 @@ from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
+import numpy as np
 import torch
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
@@ -403,6 +405,158 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
         return valid_spec_dict
 
+    def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
+        """memory cost of the communication action sequence
+        TODO: currently we only consider the tensor numel in the shape consistency manager;
+        as the manager itself doesn't have access to the tensor dtype, we need to take it
+        into consideration in memory estimation.
+
+        Args:
+            comm_action_sequence (List[CommSpec]): list of communication actions
+
+        Returns:
+            TrainCycleItem: memory (numel) cost of such comm_action_sequence
+        """
+
+        def compute_shape(sharding_spec: ShardingSpec):
+            shape = list(sharding_spec.entire_shape)
+            for dim, shard in sharding_spec.dim_partition_dict.items():
+                shape[dim] = shape[dim] // len(shard)
+            return shape
+
+        def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_gather memory footprint
+            all_gather will allocate memory for the output tensor, and there will be temporary memory
+            for the all_gather operation, which is twice the size of the output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+            return alloc_numel, peak_numel
+
+        def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze split memory footprint
+            split will allocate memory for the output tensor if we don't shard the input tensor on its
+            first dimension. If we do shard it on the first dimension, `torch.Tensor.contiguous()` will
+            not generate a new tensor in this case, so no memory will be allocated.
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                # if we don't shard the tensor on the first dimension, the split action will
+                # generate a new tensor
+                input_shape = compute_shape(comm_spec.sharding_spec)
+                input_numel = np.prod(input_shape)
+                output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+                alloc_numel += output_numel
+                peak_numel = max(peak_numel, alloc_numel)
+                if discard_input:
+                    alloc_numel -= input_numel
+            else:
+                # if we shard the tensor on the first dimension, the split action will not generate
+                # a new tensor, and as it will preserve a reference to the input tensor, we could
+                # override the discard_input option here
+                # NOTE: this special case might fail in some weird cases, e.g. if we have three split
+                # actions in the comm action sequence, where the first split action operates on the second
+                # dimension, the second split action operates on the first dimension, and the third split
+                # action operates, again, on the second dimension. After the first two actions in the
+                # sequence, we will have allocated memory of the same size as the output of the first split
+                # action. However, the third split action will discard its input tensor, while it actually
+                # should discard the tensor generated by the first split action, so in the current memory
+                # estimation framework we will overestimate the memory usage. But the above case is kind of
+                # weird, and I think we could ignore it for now.
+                pass
+            return alloc_numel, peak_numel
+
+        def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
+            """
+            return alloc_numel, peak_numel
+
+        def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """analyze all_to_all memory footprint
+            all_to_all will allocate memory for the output tensor; the temporary memory of the all_to_all
+            action is twice the size of the output tensor if we shard the input tensor on its first
+            dimension, otherwise the temporary memory is three times the size of the output tensor
+
+            Args:
+                comm_spec (CommSpec): input CommSpec
+                discard_input (bool): whether to discard the input tensor
+                alloc_numel (int): current allocated numel
+                peak_numel (int): current peak numel
+            """
+            input_shape = compute_shape(comm_spec.sharding_spec)
+            input_numel = np.prod(input_shape)
+            output_numel = input_numel
+            shard_dim = comm_spec.shard_dim
+            if shard_dim != 0:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
+            else:
+                peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
+            alloc_numel += output_numel
+            if discard_input:
+                alloc_numel -= input_numel
+            return alloc_numel, peak_numel
+
+        def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
+            """
+            a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
+            """
+            return alloc_numel, peak_numel
+
+        pattern_to_func_dict = {
+            CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
+            CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
+            CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
+            CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
+            CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
+            CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
+        }
+
+        fwd_actions = []
+        bwd_actions = []
+
+        # construct forward and backward comm action sequences
+        for comm_spec in comm_action_sequence:
+            comm_spec: CommSpec
+            fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
+            fwd_actions.append(fwd_action)
+            bwd_actions.append(bwd_action)
+
+        # analyze memory footprint of the forward comm action sequence
+        fwd_alloc_numel = 0
+        fwd_peak_numel = 0
+        for idx, (fwd_action, comm_spec) in enumerate(zip(fwd_actions, comm_action_sequence)):
+            # the first forward comm action will not discard its input
+            if idx == 0:
+                fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
+            else:
+                fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+
+        # analyze memory footprint of the backward comm action sequence
+        bwd_alloc_numel = 0
+        bwd_peak_numel = 0
+        for bwd_action, comm_spec in zip(reversed(bwd_actions), reversed(comm_action_sequence)):
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+
+        fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
+        bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
+        total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
+
+        return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
+
     def shape_consistency(self, source_spec: ShardingSpec,
                           target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
         '''
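
To make the bookkeeping concrete, below is a minimal standalone sketch of the forward accounting that mem_cost performs for a single all_gather step. It does not use the real CommSpec or DeviceMesh objects; the mesh size of 4 and the (16, 32) input shard are assumed purely for illustration, and only numel is tracked, as in the patch.

import numpy as np

def gather_step(input_shape, mesh_size, alloc_numel=0, peak_numel=0, discard_input=False):
    # mirrors gather_analysis: the all_gather output plus a temp buffer of the same
    # size set the peak, then the output stays allocated as the activation
    input_numel = int(np.prod(input_shape))
    output_numel = input_numel * mesh_size
    peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
    alloc_numel += output_numel
    if discard_input:
        alloc_numel -= input_numel
    return alloc_numel, peak_numel

alloc, peak = gather_step((16, 32), mesh_size=4)
print({"activation": alloc, "temp": peak - alloc})    # {'activation': 2048, 'temp': 2048}

For the backward direction of GATHER_FWD_SPLIT_BWD, the paired split_analysis runs instead and only allocates a new shard when shard_dim != 0.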
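
The all_to_all rule can be checked the same way. This sketch mirrors all2all_analysis under the same assumptions (numel only, illustrative shapes) and shows why sharding on the first dimension needs one fewer temporary buffer.

import numpy as np

def all2all_step(input_shape, shard_dim, alloc_numel=0, peak_numel=0, discard_input=True):
    # mirrors all2all_analysis: output numel equals input numel, but the temp
    # requirement is 3x the output when shard_dim != 0 and 2x when shard_dim == 0
    input_numel = int(np.prod(input_shape))
    output_numel = input_numel
    factor = 3 if shard_dim != 0 else 2
    peak_numel = max(peak_numel, alloc_numel + output_numel * factor)
    alloc_numel += output_numel
    if discard_input:
        alloc_numel -= input_numel
    return alloc_numel, peak_numel

print(all2all_step((16, 32), shard_dim=1))    # (0, 1536): peak holds three output-sized buffers
print(all2all_step((16, 32), shard_dim=0))    # (0, 1024): first-dim shard needs one fewer buffer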