From ab38aebaced3e77f8fe5566b2ac28ad10ccd8eac Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Mon, 2 Jan 2023 16:25:18 +0800
Subject: [PATCH] [autoparallel] Hook all meta information on ResNet nodes for
 auto activation checkpoint (#2248)

* [autoparallel] hook node meta on graph nodes for checkpoint solver

* [autoparallel] polish code

* [autoparallel] restore some node handlers

* colossalai/auto_parallel/passes/meta_info_prop.py

* [autoparallel] remove some unused import

* [autoparallel] hook bwd_mem_out
---
 .../meta_registry/binary_elementwise_ops.py   |   2 +-
 .../auto_parallel/meta_profiler/metainfo.py   |  22 +---
 .../passes/comm_metainfo_pass.py              | 113 ++++++++++++++++++
 .../auto_parallel/passes/meta_info_prop.py    |  19 ++-
 .../passes/runtime_apply_pass.py              |  49 --------
 .../tensor_shard/node_handler/node_handler.py |   3 +-
 6 files changed, 132 insertions(+), 76 deletions(-)
 create mode 100644 colossalai/auto_parallel/passes/comm_metainfo_pass.py

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index eb8042368..b4cc58d05 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
 
diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py
index 1f3463713..ff76e3059 100644
--- a/colossalai/auto_parallel/meta_profiler/metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -1,6 +1,5 @@
 from typing import Callable, List
 
-import numpy as np
 import torch
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -71,25 +70,12 @@ class MetaInfo:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()
 
-    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
         """
-        Compute sharded meta tensor based on the given data and sharding spec.
+        Compute sharded opdata based on the given data and sharding spec.
""" - shard_sequnce = sharding_spec.sharding_sequence - device_mesh = sharding_spec.device_mesh - shape = operation_data.data.shape - - new_shape = [] - for dim, shard in zip(shape, shard_sequnce): - if shard.is_replica: - # replica - new_shape.append(dim) - else: - # sharded according to device_mesh shape - new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list]))) - return OperationData(name=operation_data.name, - data=torch.zeros(new_shape, device="meta"), + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), type=operation_data.type, logical_shape=operation_data.logical_shape) @@ -113,7 +99,7 @@ class MetaInfo: save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION # construct args for meta_func - args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] + args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()] # construct kwargs if self.target in INPLACE_MODULE: diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py new file mode 100644 index 000000000..5ab6289b7 --- /dev/null +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -0,0 +1,113 @@ +from typing import Dict + +import torch +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> MetaInfo: + # get comm_action_sequence and total_cost from shape_consistency_manager + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + meta_info = MetaInfo() + # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel + # get mem cost for MetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + # extract user that has _meta_data and extract element length + input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + element_length = input_node._meta_data.element_size() + + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for MetaInfo + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + # get tensor shape for MetaInfo + origin_sharding_spec: ShardingSpec + target_sharding_spec: ShardingSpec + input_shape = origin_sharding_spec.get_sharded_shape_per_device() + output_shape = target_sharding_spec.get_sharded_shape_per_device() + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + +def 
+    """
+    This method is used to construct `MetaInfo` for a shape consistency node.
+    """
+
+    # extract node index and user node index
+    args = node.args
+    node_index, user_node_index = args[3], args[4]
+    origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
+        node_index][user_node_index]
+
+    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+    # extract node_index and op_data_name
+    node_index, op_data_name = node.args[2], node.args[3]
+
+    comm_action = comm_actions_dict[node_index][op_data_name]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        # this case is for all_reduce, there will be no memory cost
+        meta_info = MetaInfo()
+        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
+        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        element_length = output_node._meta_data.element_size()
+
+        total_cost = comm_action.comm_spec.get_comm_cost()
+        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                                total_cost['backward'] * element_length,
+                                                total_cost['total'] * element_length)
+
+        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
+        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_buffer = []
+        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    else:
+        # this case will be handled by shape consistency manager
+        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
+            'tgt_spec']
+        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+    return meta_info
+
+
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
+                       comm_actions_dict: Dict):
+    """
+    This pass hooks the metainfo onto every communication node (runtime_apply, runtime_comm_spec_apply) in the graph.
+    """
+    for node in gm.graph.nodes:
+        if node.target == runtime_apply:
+            setattr(node, 'best_metainfo',
+                    _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
+        elif node.target == runtime_comm_spec_apply:
+            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+        else:
+            pass
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index 1628bb285..607f7e17e 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -1,15 +1,14 @@
 import uuid
 from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Tuple
+from typing import List
 
 import torch
 import torch.fx
 from torch.fx import GraphModule
-from torch.fx.node import Argument, Node, Target
-from torch.utils._pytree import tree_map
+from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler import MetaInfo
-from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
 from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 
@@ -68,7 +67,7 @@ class MetaInfoProp:
         """
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
-        graph_info.fwd_out = list(out)
+        graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}
 
     @compatibility(is_backward_compatible=False)
@@ -97,7 +96,7 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
+        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
         graph_info = GraphInfo()
         meta_info = node.best_metainfo
         meta_info: MetaInfo
@@ -158,5 +157,13 @@ class MetaInfoProp:
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
+        graph_info.bwd_mem_out = memory_cost.bwd.activation
+
+        # fetch flop information
+        # here we use fwd_time and bwd_time to deal with the case that
+        # communication cost is a float
+        compute_cost = meta_info.compute_cost
+        graph_info.fwd_time = compute_cost.fwd
+        graph_info.bwd_time = compute_cost.bwd
 
         node.meta = {**asdict(graph_info)}
diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 5d224542c..7f2aac42b 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst
 
 
-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInto` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
- """ - - def compute_shape(sharding_spec: ShardingSpec): - shape = sharding_spec.entire_shape - new_shape = [] - for dim, shard in sharding_spec.dim_partition_dict.items(): - new_shape.append(shape[dim] // len(shard)) - return new_shape - - meta_info = MetaInfo() - origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name( - str(node.name)) - _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) - - # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel - # get mem cost for MetaInfo - mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) - element_length = node._meta_data.element_size() - mem_cost.fwd.activation *= element_length - mem_cost.fwd.temp *= element_length - mem_cost.bwd.activation *= element_length - mem_cost.bwd.temp *= element_length - mem_cost.total.activation *= element_length - - meta_info.memory_cost = mem_cost - - # get computation cost for MetaInfo - compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total']) - meta_info.compute_cost = compute_cost - - # get tensor shape for MetaInfo - input_shape = compute_shape(origin_sharding_spec) - output_shape = compute_shape(target_sharding_spec) - - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] - meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] - - return meta_info - - def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): """ This method will be invoked during runtime to apply the comm action following the instruction of comm spec. @@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - meta_info = construct_meta_info(node, user_node) - setattr(shape_consistency_node, 'best_metainfo', meta_info) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 7dea256b3..af3cb5810 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -138,8 +138,7 @@ class NodeHandler(ABC): return None if self.node.op == 'call_module': - submod = self.node.graph.owning_module.get_submodule(self.node.target) - target = type(submod) + target = self.node.graph.owning_module.get_submodule(self.node.target) elif self.node.op == 'call_function': target = self.node.target elif self.node.op == 'call_method':