From 24246f7aa5c4a7efb043bd61fe5b00f272f278ef Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Wed, 28 Dec 2022 13:37:40 +0800 Subject: [PATCH] [autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162) * [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel * [fx] add conv metainfo class * [fx] restore profiler * [fx] restore meta profiler * [autoparallel] modify unit test * [fx] modify unit test * [autoparallel] add batchnorm metainfo class * [autoparallel] fix batchnorm unit test function declaration * [fx] restore profiler * [fx] add relu metainfo class * [fx] restore profiler * [autoparallel] modify metainfo input * [autoparallel] add pooling metainfo * [autoparallel] add F.linear metainfo generator * [autoparallel] add binary elementwise metainfo * [fx] recover profiler * [autoparallel] fix forward memory calculation * [autoparallel] modify constants.py * [autoparallel] remove redundant print * [autoparallel] add F.conv metainfo * [autoparallel] linear fix * [autoparallel] memory estimation for communication actions * [autoparallel] fix docstring * [autoparallel] fix variables name * [autoparallel] attach tensor to metainfo class * [autoparallel] fix dangerous try except * [autoparallel] attach memory cost to shape consistency node * [autoparallel] attach shape consistency node's metainfo to the node * [autoparallel] remove todo in shape consistency memory estimation * [autoparallel] fix the annotation --- .../meta_profiler/meta_registry/activation.py | 10 ++-- .../meta_registry/binary_elementwise_ops.py | 10 ++-- .../meta_profiler/meta_registry/conv.py | 8 +-- .../meta_profiler/meta_registry/linear.py | 8 +-- .../meta_profiler/meta_registry/norm.py | 8 +-- .../meta_profiler/meta_registry/pooling.py | 16 +++--- .../auto_parallel/meta_profiler/metainfo.py | 32 +++++++----- .../passes/runtime_apply_pass.py | 50 +++++++++++++++++++ colossalai/tensor/shape_consistency.py | 16 +++--- .../test_metainfo/test_batchnorm_metainfo.py | 2 +- .../test_metainfo/test_linear_metainfo.py | 2 +- 11 files changed, 118 insertions(+), 44 deletions(-) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 7b2f8dfa4..909232e61 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -64,7 +64,11 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - # store fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + # NOTE: It might seems a little bit weird here, we just want to align it with the older version + # of MetaInfoProp. In the future we might modify this part to make it clearer. 
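The placeholders stored in these lists are meta-device tensors: they carry shape and dtype but allocate no storage, so activation sizes can still be derived from them later. A minimal standalone sketch of what such a placeholder provides, using an illustrative shape rather than the registry's real output tensor:

import torch

# illustrative shape only; the registries build these with torch.zeros_like(..., device='meta')
output_tensor = torch.empty(16, 64, dtype=torch.float32, device='meta')
placeholder = torch.zeros_like(output_tensor, device='meta')

# bytes this activation would occupy if materialized: numel * element_size
activation_bytes = placeholder.numel() * placeholder.element_size()
print(placeholder.shape, placeholder.dtype, activation_bytes)   # torch.Size([16, 64]) torch.float32 4096
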
+ fwd_in = [] + fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, memory_cost, fwd_in + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index 0292121b6..eb8042368 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -6,7 +6,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, from colossalai.fx.profiler.memory_utils import activation_size from colossalai.fx.profiler.opcount import flop_mapping -from ..constants import BCAST_FUNC_OP +from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION from ..registry import meta_register __all__ = ['binary_elementwise_meta_info'] @@ -59,7 +59,9 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) - # store fwd_in - fwd_in = fwd_in_args + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_op_data.data, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] - return compute_cost, memory_cost, fwd_in + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index fd6c5184a..d1bb6e7fa 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -129,7 +129,9 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - # store fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, memory_cost, fwd_in + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index bb7935d0f..61f8fdff3 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -164,7 +164,9 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - # store fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, memory_cost, fwd_in + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index b88bed88b..9b34332db 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -95,7 +95,9 @@ def batchnormnd_meta_info(*args, **kwargs) 
-> Tuple[TrainCycleItem, TrainCycleIt memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) - # store fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, memory_cost, fwd_in + return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index 1c04bdc73..3ecabb6dc 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -59,10 +59,12 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) - # store_fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [] + fwd_buffer = [] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, mem_cost, fwd_in + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out @meta_register.register(torch.nn.MaxPool1d) @@ -122,7 +124,9 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) - # store_fwd_in - fwd_in = [input_tensor] + # store fwd_in, fwd_buffer, fwd_out + fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device='meta')] - return compute_cost, mem_cost, fwd_in + return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py index b7cbc57bd..1f3463713 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, List import numpy as np import torch @@ -33,10 +33,13 @@ class MetaInfo: self.memory_cost: TrainCycleItem # list of input tensors - self.fwd_in: list[OperationData] + self.fwd_in: List[torch.Tensor] - # bool type to indicate whether the function will save forward activation - self.save_fwd_in: bool + # list of buffer tensors + self.fwd_buffer: List[torch.Tensor] + + # list of output tensors + self.fwd_out: List[torch.Tensor] # sharding strategy self._strategy = strategy @@ -94,19 +97,20 @@ class MetaInfo: """ Compute meta info based on sharding strategy and the given target function. """ - - try: + assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ + f"Meta info for {self._target} is not registered." 
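This assert replaces the bare `try`/`except:` that previously chose between the module and function branches; a bare `except:` silently swallows any unrelated error raised inside the block. A standalone sketch of the same lookup pattern, with a plain dict standing in for `meta_register`:

import torch

# hypothetical registry: modules are keyed by their class, free functions by the callable itself
registry = {torch.nn.ReLU: 'relu_meta_info', torch.nn.functional.linear: 'linear_meta_info'}

def lookup(target):
    # fail loudly when nothing is registered instead of masking the error
    assert target.__class__ in registry or target in registry, \
        f"Meta info for {target} is not registered."
    key = target.__class__ if target.__class__ in registry else target
    return registry[key]

print(lookup(torch.nn.ReLU()))             # module branch
print(lookup(torch.nn.functional.linear))  # function branch
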
+ if meta_register.has(self._target.__class__): # module meta_func = meta_register.get(self._target.__class__) - # check whether the target in the module list that we don't need to save activation - self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION - except: + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION + else: # function meta_func = meta_register.get(self._target) - # check whether the target in the module list that we don't need to save activation - self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION + # check whether the target in the list that we don't need to save activation + save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION # construct args for meta_func args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] @@ -118,4 +122,8 @@ class MetaInfo: kwargs = {'inplace': False} # compute metainfo with meta_func - self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args, **kwargs) + self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) + + # process corner case for NO_SAVE_ACTIVATION + if not save_fwd_in: + self.fwd_in = [] diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index b81402c27..caf118c89 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -4,11 +4,13 @@ from typing import Dict, List import torch from torch.fx.node import Node +from colossalai.auto_parallel.meta_profiler import MetaInfo from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, OperationData, OperationDataType, + TrainCycleItem, ) from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.comm_spec import CommSpec @@ -45,6 +47,52 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: return rst +def construct_meta_info(node: Node, user_node: Node) -> MetaInfo: + """ + This method is used to construct `MetaInto` for shape consistency node + TODO: Actually we could attain the cost information from resharding cost in node + handler, we should modify this part in the future. 
+ """ + + def compute_shape(sharding_spec: ShardingSpec): + shape = sharding_spec.entire_shape + new_shape = [] + for dim, shard in sharding_spec.dim_partition_dict.items(): + new_shape.append(shape[dim] // len(shard)) + return new_shape + + meta_info = MetaInfo() + origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel + # get mem cost for MetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + element_length = node._meta_data.element_size() + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for MetaInfo + compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total']) + meta_info.compute_cost = compute_cost + + # get tensor shape for MetaInfo + input_shape = compute_shape(origin_sharding_spec) + output_shape = compute_shape(target_sharding_spec) + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): """ This method will be invoked during runtime to apply the comm action following the instruction of comm spec. @@ -126,6 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) + meta_info = construct_meta_info(node, user_node) + setattr(shape_consistency_node, 'best_metainfo', meta_info) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 144712fc5..daf81034f 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -407,9 +407,6 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem: """memory cost of the communication action sequence - TODO: Currently we just consider tensor numel in the shape consistency manger, - as the manager itself doesn't have the access to tensor dtype, we need to take - it into consideration in memory estimation. 
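The removed TODO is handled one level up: `construct_meta_info` in the runtime apply pass multiplies each of these element counts by `node._meta_data.element_size()` to turn them into bytes. A standalone sketch of that conversion with illustrative numbers:

import torch

numel_cost = 1024                                                    # element count reported by mem_cost
element_length = torch.empty(0, dtype=torch.float16).element_size()  # 2 bytes per fp16 element
byte_cost = numel_cost * element_length                              # 2048 bytes stored on the MetaInfo
print(byte_cost)
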
Args: comm_action_sequence (List[CommSpec]): list of communication actions @@ -420,9 +417,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): def compute_shape(sharding_spec: ShardingSpec): shape = sharding_spec.entire_shape + new_shape = [] for dim, shard in sharding_spec.dim_partition_dict.items(): - shape[dim] = shape[dim] // len(shard) - return shape + new_shape.append(shape[dim] // len(shard)) + return new_shape def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze all_gather memory footprint @@ -461,7 +459,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes] + output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: @@ -538,8 +536,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # analyze memory footprint of forward comm actions sequence fwd_alloc_numel = 0 fwd_peak_numel = 0 - for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)): + for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input + fwd_action, comm_spec = action_spec_pair if idx == 0: fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) else: @@ -548,7 +547,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 - for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): + for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): + bwd_action, comm_spec = action_spec_pair bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py index 7acbbed8f..826c74666 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py @@ -37,7 +37,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): # index of target node in computation graph node_index = 1 # total number of target node strategies - strategy_number = 4 + strategy_number = 9 mem_test_for_node_strategy(rank=rank, model=model, device_mesh=device_mesh, diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index 62fe11e22..e9c0601eb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port): model=model, device_mesh=device_mesh, node_index=2, - strategy_number=23, + strategy_number=24, input_args=[input], meta_arg_names=["input"])
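
The gather and all-to-all analyses above keep two counters per direction, the currently allocated element count and the peak it reaches; discarding an input releases its buffer, and the first forward action keeps its input alive. A minimal standalone version of that accounting with made-up sizes:

# (input_numel, output_numel) for a hypothetical two-step resharding sequence
steps = [(4096, 1024), (1024, 1024)]

alloc_numel, peak_numel = 0, 0
for idx, (input_numel, output_numel) in enumerate(steps):
    alloc_numel += output_numel                  # allocate the communication output buffer
    peak_numel = max(peak_numel, alloc_numel)    # record the peak right after the allocation
    if idx != 0:                                 # first action does not discard its input
        alloc_numel -= input_numel               # discarding the input frees its buffer

print(alloc_numel, peak_numel)                   # 1024 2048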
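
The loop change above is needed because `enumerate(zip(...))` yields `(index, (action, spec))` pairs, so the old three-name unpacking `for idx, fwd_action, comm_spec in enumerate(zip(...))` raises a ValueError on the first iteration. A minimal illustration with plain strings standing in for the actions and comm specs:

fwd_actions = ['all_gather_analysis', 'all_to_all_analysis']
comm_action_sequence = ['comm_spec_0', 'comm_spec_1']

for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
    fwd_action, comm_spec = action_spec_pair     # unpack the inner (action, spec) tuple
    print(idx, fwd_action, comm_spec)

# the old form raises: ValueError: not enough values to unpack (expected 3, got 2)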