diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index 63eff31b2..b388d00ac 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -5,8 +5,12 @@ from typing import Any, List import torch from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -from colossalai.fx.profiler.memory_utils import is_inplace __all___ = ['CheckpointSolverBase'] @@ -31,10 +35,11 @@ class CheckpointSolverBase(ABC): free_memory: float = -1.0, requires_linearize: bool = False, cnode: List[str] = None, + optim_multiplier: float = 1.0, ): - """CheckpointSolver class will integrate information provided by the components - and use an existing solver to find a possible optimal strategies combination for - target computing graph. + """``CheckpointSolverBase`` class will integrate information provided by the components + and use an existing solver to find a possible optimal strategies combination for target + computing graph. Existing Solvers: Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) @@ -45,9 +50,11 @@ class CheckpointSolverBase(ABC): free_memory (float): Memory constraint for the solution. requires_linearize (bool): Whether the graph needs to be linearized. cnode (List[str], optional): Common node List, should be the subset of input. Default to None. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. Warnings: - `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required. + Meta information of the graph is required for any ``CheckpointSolver``. """ # super-dainiu: this graph is a temporary graph which can refer to # the owning module, but we will return another deepcopy of it after @@ -57,13 +64,14 @@ class CheckpointSolverBase(ABC): _copy_output(graph, self.graph) self.graph.set_codegen(ActivationCheckpointCodeGen()) - # check if `MetaInfoProp` is done + # check if has meta information if any(len(node.meta) == 0 for node in self.graph.nodes): raise RuntimeError( - "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!") + "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!" + ) - self.free_memory = free_memory - self.parameter_size = _get_param_size(self.graph.owning_module) + # parameter memory = parameter size + optimizer extra weight storage + self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1) self.cnode = cnode self.requires_linearize = requires_linearize if self.requires_linearize: @@ -93,7 +101,7 @@ class CheckpointSolverBase(ABC): the actual 'node' in linearized manner. Remarks: - Do merge the inplace ops into the previous node. + Do merge the inplace ops and shape-consistency ops into the previous node. 
""" # Common nodes are type of nodes that could be seen as attributes and remain @@ -131,7 +139,23 @@ class CheckpointSolverBase(ABC): bool """ - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + def _is_inplace(n: Node): + """Get the inplace argument from ``torch.fx.Node`` + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) # make sure that item in cnode is valid if self.cnode: diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py index 58878253e..19b2ef598 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase): Note that this algorithm targets at memory optimization only, using techniques in appendix A. Usage: - Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + Assume that we have a ``GraphModule``, and we have already done the extractions to the graph to retrieve all information needed, then we could use the following - code to find a solution using `CheckpointSolverChen`: + code to find a solution using ``CheckpointSolverChen``: >>> solver = CheckpointSolverChen(gm.graph) >>> chen_graph = solver.solve() >>> gm.graph = chen_graph # set the graph to a new graph @@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase): def grid_search(self) -> Set: """ Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. - Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A. + Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A. 
""" _, b_approx = self.run_chen_greedy(0) b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 72bc67e02..41d23be5c 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple from torch import Tensor from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions from colossalai.fx.profiler import ( activation_size, @@ -22,15 +23,20 @@ __all__ = ['CheckpointSolverRotor'] class CheckpointSolverRotor(CheckpointSolverBase): - def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500): + def __init__(self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0): """This is the simple implementation of dynamic programming algorithm rotor in https://hal.inria.fr/hal-02352969. Some code are adapted from https://gitlab.inria.fr/hiepacs/rotor. Usage: - Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` + Assume that we have a ``GraphModule``, and we have already done the extractions to the graph to retrieve all information needed, then we could use the following - code to find a solution using `CheckpointSolverRotor`: + code to find a solution using ``CheckpointSolverRotor``: >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0]) >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver >>> gm.graph = rotor_graph # set the graph to a new graph @@ -41,8 +47,10 @@ class CheckpointSolverRotor(CheckpointSolverBase): Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1. cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + optim_multiplier (float, optional): The multiplier of extra weight storage for the + ``torch.optim.Optimizer``. Default to 1.0. 
""" - super().__init__(graph, free_memory, True, cnode) + super().__init__(graph, free_memory, True, cnode, optim_multiplier) self.memory_slots = memory_slots # construct chain @@ -128,16 +136,24 @@ class CheckpointSolverRotor(CheckpointSolverBase): xbar = 0 ftime = 0 btime = 0 + fwd_mem_peak = 0 for n in node: assert isinstance(n, Node), f'{n} is not a Node' - xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + if n.target == runtime_apply or n.target == runtime_comm_spec_apply: + # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta + xbar += n.meta['fwd_mem_out'] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + else: + xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) btime += max(calculate_bwd_time(n), 1.0) x = calculate_fwd_out(node[-1]) xbar = max(x, xbar) - ftmp = cls._extract_ftmp(node) + ftmp = fwd_mem_peak - xbar btmp = cls._extract_btmp(node) return ftime, btime, x, xbar, ftmp, btmp @@ -151,10 +167,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): return input_tensors @staticmethod - def _extract_ftmp(node: List[Node]) -> int: - """Extract ftmp from a list of nodes""" - n = node[-1] - return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n) + def _extract_unused_output(node: Node) -> int: + """Extract unused output from `torch.fx.Node`""" + return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) @staticmethod def _extract_btmp(node: List[Node]) -> int: @@ -290,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase): lhs (int): The left index of the interval to backtrack. rhs (int): The right index of the interval to backtrack. budget (int): The memory budget for processing this interval. - cost_table (List[Any]): See `._compute_table()` for definitions - back_ptr (List[Any]): See `._compute_table()` for definitions + cost_table (List[Any]): See ``._compute_table()`` for definitions + back_ptr (List[Any]): See ``._compute_table()`` for definitions Raises: ValueError: Can not process the chain. @@ -332,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): @staticmethod def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): - """Annotate the nodes in the node_list with activation checkpoint from the sequence. + """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence. Args: sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations. 
diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py index 714674b7b..35b8c13ee 100644 --- a/colossalai/auto_parallel/meta_profiler/constants.py +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -5,8 +5,11 @@ import torch.nn as nn from ..tensor_shard.constants import * -# list of inplace operations +# list of inplace module INPLACE_MODULE = [nn.ReLU] +# list of inplace operations +INPLACE_OPS = [torch.flatten] + # list of operations that do not save forward activations NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index eb8042368..281a92c0d 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs """ - input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] + input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT] output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args)) # construct forward args for flop mapping - fwd_in_args = [input_op_data.data, other_op_data.data] + fwd_in_args = [opdata.data for opdata in input_op_data] fwd_out_args = [output_op_data.data] # calculate cost # calculate compute cost # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case - fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args) + fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args) bwd_compute_cost = fwd_compute_cost * 2 compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost - param_mem_cost = activation_size( - [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM]) + param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM]) fwd_mem_cost = MemoryCost( - activation=activation_size([input_op_data.data, output_op_data.data]), + activation=activation_size(output_op_data.data), parameter=param_mem_cost, ) bwd_mem_cost = MemoryCost( @@ -60,7 +59,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_op_data.data, device='meta')] + fwd_in = [] fwd_buffer = [] fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py index 1f3463713..218187768 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/metainfo.py @@ -1,6 +1,5 @@ from typing import Callable, List -import numpy as np import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( @@ -13,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ) from colossalai.tensor.sharding_spec import ShardingSpec 
-from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION +from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register __all__ = ['MetaInfo'] @@ -71,25 +70,12 @@ class MetaInfo: if self._strategy is not None and self._target is not None: self.compute_metainfo() - def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: + def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: """ - Compute sharded meta tensor based on the given data and sharding spec. + Compute sharded opdata based on the given data and sharding spec. """ - shard_sequnce = sharding_spec.sharding_sequence - device_mesh = sharding_spec.device_mesh - shape = operation_data.data.shape - - new_shape = [] - for dim, shard in zip(shape, shard_sequnce): - if shard.is_replica: - # replica - new_shape.append(dim) - else: - # sharded according to device_mesh shape - new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list]))) - return OperationData(name=operation_data.name, - data=torch.zeros(new_shape, device="meta"), + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), type=operation_data.type, logical_shape=operation_data.logical_shape) @@ -113,11 +99,13 @@ class MetaInfo: save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION # construct args for meta_func - args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] + args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()] # construct kwargs if self.target in INPLACE_MODULE: kwargs = {'inplace': self.target.inplace} + elif self.target in INPLACE_OPS: + kwargs = {'inplace': True} else: kwargs = {'inplace': False} diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py new file mode 100644 index 000000000..ab3acb056 --- /dev/null +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -0,0 +1,113 @@ +from typing import Dict + +import torch +from torch.fx import GraphModule +from torch.fx.node import Node + +from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.tensor.comm_spec import CommSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec + +shape_consistency_manager = ShapeConsistencyManager() + + +def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> MetaInfo: + # get comm_action_sequence and total_cost from shape_consistency_manager + _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + origin_sharding_spec, target_sharding_spec) + + meta_info = MetaInfo() + # NOTE: the costs in shape_consistency_manager.mem_cost are counted in numel (number of elements) + # get mem cost for MetaInfo + mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) + # extract an input node that has _meta_data and get its element size + input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + element_length = input_node._meta_data.element_size() + + mem_cost.fwd.activation *= element_length + mem_cost.fwd.temp *= element_length + 
mem_cost.bwd.activation *= element_length + mem_cost.bwd.temp *= element_length + mem_cost.total.activation *= element_length + + meta_info.memory_cost = mem_cost + + # get computation cost for MetaInfo + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + # get tensor shape for MetaInfo + origin_sharding_spec: ShardingSpec + target_sharding_spec: ShardingSpec + input_shape = origin_sharding_spec.get_sharded_shape_per_device() + output_shape = target_sharding_spec.get_sharded_shape_per_device() + + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + + return meta_info + + +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: + """ + This method is used to construct `MetaInfo` for the shape-consistency node + """ + + # extract node index and user node index + args = node.args + node_index, user_node_index = args[3], args[4] + origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ + user_node_index] + + return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: + # extract node_index and op_data_name + node_index, op_data_name = node.args[2], node.args[3] + + comm_action = comm_actions_dict[node_index][op_data_name] + if isinstance(comm_action.comm_spec, CommSpec): + # this case is for all_reduce, where there is no memory cost + meta_info = MetaInfo() + meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost()) + output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + element_length = output_node._meta_data.element_size() + + total_cost = comm_action.comm_spec.get_comm_cost() + meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, + total_cost['backward'] * element_length, + total_cost['total'] * element_length) + + input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() + meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_buffer = [] + meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + else: + # this case will be handled by shape consistency manager + origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ + 'tgt_spec'] + meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + + return meta_info + + +def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, + comm_actions_dict: Dict) -> GraphModule: + """ + This pass manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph. 
+ """ + for node in gm.graph.nodes: + if node.target == runtime_apply: + setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + elif node.target == runtime_comm_spec_apply: + setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + else: + pass + return gm diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py new file mode 100644 index 000000000..b86088474 --- /dev/null +++ b/colossalai/auto_parallel/passes/constants.py @@ -0,0 +1,8 @@ +import torch + +OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten] + +OUTPUT_SAVED_MOD = [ + torch.nn.ReLU, + torch.nn.Softmax, +] diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 1628bb285..f7e07ef1e 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -1,17 +1,16 @@ import uuid from dataclasses import asdict -from typing import Any, Dict, List, NamedTuple, Tuple +from typing import List import torch import torch.fx from torch.fx import GraphModule -from torch.fx.node import Argument, Node, Target -from torch.utils._pytree import tree_map +from torch.fx.node import Node from colossalai.auto_parallel.meta_profiler import MetaInfo -from colossalai.fx._compatibility import compatibility, is_compatible_with_meta +from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS +from colossalai.fx._compatibility import compatibility from colossalai.fx.profiler import GraphInfo -from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS def _normalize_tuple(x): @@ -47,7 +46,7 @@ class MetaInfoProp: """ Check if the node is inplace operation. 
""" - if node.op == 'call_method': + if node.op == 'call_module': return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD elif node.op == "call_function": return node.target in OUTPUT_SAVED_OPS @@ -68,7 +67,7 @@ class MetaInfoProp: """ graph_info = GraphInfo() out = _normalize_tuple(getattr(node, '_meta_data', None)) - graph_info.fwd_out = list(out) + graph_info.fwd_out = list(out) if out[0] is not None else [] node.meta = {**asdict(graph_info)} @compatibility(is_backward_compatible=False) @@ -97,66 +96,70 @@ class MetaInfoProp: """ Handle other kind of nodes """ - assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}" + assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}" graph_info = GraphInfo() meta_info = node.best_metainfo meta_info: MetaInfo # set data_ptr for input_tensor in MetaInfo class - input_tensor: List[torch.Tensor] = meta_info.fwd_in - buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer - output_tensor: List[torch.Tensor] = meta_info.fwd_out + input_tensors: List[torch.Tensor] = meta_info.fwd_in + buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer + output_tensors: List[torch.Tensor] = meta_info.fwd_out - if len(input_tensor) > 0: + if self._is_inplace(node): + # inplace operation will not create new tensor, and it only has one parent node + # TODO: Verify this observation + # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node + parent_node = list(node._input_nodes.keys())[0] + parent_tensor = parent_node.meta.get("fwd_out")[0] + parent_tensor: torch.Tensor + for tensor in input_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in buffer_tensors: + tensor.data_ptr = parent_tensor.data_ptr + for tensor in output_tensors: + tensor.data_ptr = parent_tensor.data_ptr + + else: for par in node._input_nodes: - if par.meta: - if len(par.meta["fwd_out"]) > 0: - # set data_ptr for the input_tensor of current node from the output_tensor of its parent node - for tensor in par.meta["fwd_out"]: - tensor: torch.Tensor - target_tensor = next( - (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None) - target_tensor.data_ptr = tensor.data_ptr + # set data_ptr for the input_tensor of current node from the output_tensor of its parent node + for tensor in par.meta.get("fwd_out", []): + tensor: torch.Tensor + target_input_tensor = next( + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + if target_input_tensor is not None: + target_input_tensor.data_ptr = tensor.data_ptr # set data_ptr for tensor in input_tensor that is not set - for tensor in input_tensor: + for tensor in input_tensors: if not tensor.data_ptr(): self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_in = input_tensor - - if self._is_inplace(node): - # inplace operation will not create new tensor - # set data_ptr for buffer_tensor and output_tensor of current node - for tensor in input_tensor: - tensor: torch.Tensor - target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape), - None) - target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape), - None) - target_buffer_tensor.data_ptr = tensor.data_ptr - target_output_tensor.data_ptr = tensor.data_ptr - # attach them to graph_info - graph_info.fwd_tmp = buffer_tensor - graph_info.fwd_out = output_tensor - - else: # set data_ptr for buffer_tensor - for tensor in 
buffer_tensor: + for tensor in buffer_tensors: self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_tmp = buffer_tensor # set data_ptr for output_tensor - for tensor in output_tensor: + for tensor in output_tensors: self._set_data_ptr(tensor) - # attach it to graph_info - graph_info.fwd_out = output_tensor + + # attach them to graph_info + graph_info.fwd_in = input_tensors + graph_info.fwd_tmp = buffer_tensors + graph_info.fwd_out = output_tensors # fetch other memory informations memory_cost = meta_info.memory_cost graph_info.fwd_mem_tmp = memory_cost.fwd.temp + graph_info.fwd_mem_out = memory_cost.fwd.activation graph_info.bwd_mem_tmp = memory_cost.bwd.temp + graph_info.bwd_mem_out = memory_cost.bwd.activation + + # fetch flop information + # here we use fwd_time and bwd_time to deal with the case that + # communication cost is a float + compute_cost = meta_info.compute_cost + graph_info.fwd_time = compute_cost.fwd + graph_info.bwd_time = compute_cost.bwd node.meta = {**asdict(graph_info)} diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 5d224542c..7f2aac42b 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: return rst -def construct_meta_info(node: Node, user_node: Node) -> MetaInfo: - """ - This method is used to construct `MetaInto` for shape consistency node - TODO: Actually we could attain the cost information from resharding cost in node - handler, we should modify this part in the future. - """ - - def compute_shape(sharding_spec: ShardingSpec): - shape = sharding_spec.entire_shape - new_shape = [] - for dim, shard in sharding_spec.dim_partition_dict.items(): - new_shape.append(shape[dim] // len(shard)) - return new_shape - - meta_info = MetaInfo() - origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name( - str(node.name)) - _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) - - # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel - # get mem cost for MetaInfo - mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) - element_length = node._meta_data.element_size() - mem_cost.fwd.activation *= element_length - mem_cost.fwd.temp *= element_length - mem_cost.bwd.activation *= element_length - mem_cost.bwd.temp *= element_length - mem_cost.total.activation *= element_length - - meta_info.memory_cost = mem_cost - - # get computation cost for MetaInfo - compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total']) - meta_info.compute_cost = compute_cost - - # get tensor shape for MetaInfo - input_shape = compute_shape(origin_sharding_spec) - output_shape = compute_shape(target_sharding_spec) - - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] - meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] - - return meta_info - - def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str): """ This method will be invoked during runtime to apply the comm action following the instruction of comm spec. 
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - meta_info = construct_meta_info(node, user_node) - setattr(shape_consistency_node, 'best_metainfo', meta_info) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index e8ae363e9..f510f7477 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler'] @operator_registry.register(BCAST_FUNC_OP) -class BinaryElementwiseHandler(NodeHandler): +class BinaryElementwiseHandler(MetaInfoNodeHandler): """ An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and broadcasting occurs such as torch.add. diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 7dea256b3..78dc58c90 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo +from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -138,8 +138,7 @@ class NodeHandler(ABC): return None if self.node.op == 'call_module': - submod = self.node.graph.owning_module.get_submodule(self.node.target) - target = type(submod) + target = self.node.graph.owning_module.get_submodule(self.node.target) elif self.node.op == 'call_function': target = self.node.target elif self.node.op == 'call_method': @@ -235,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler): """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() - metainfo_vector = [] - for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) - strategy.compute_cost = metainfo.compute_cost - strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) - - # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. 
+ # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) return self.strategies_vector @@ -282,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler): """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() - metainfo_vector = [] - for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) - strategy.compute_cost = metainfo.compute_cost - strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) - - # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + # Currently we haven't patched all the torch functions and modules, so if the target + # is not patched, we will use the default cost model to compute the cost. + # TODO: patch all torch functions and modules to make it clean + if meta_register.has(target.__class__) or meta_register.has(target): + metainfo_vector = [] + for strategy in self.strategies_vector: + metainfo = MetaInfo(strategy, target) + strategy.compute_cost = metainfo.compute_cost + strategy.memory_cost = metainfo.memory_cost + metainfo_vector.append(metainfo) + + # attach metainfos to the handler + setattr(self, "metainfo_vector", metainfo_vector) return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py index b46348716..7763b1884 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py @@ -3,7 +3,7 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import ReshapeGenerator, StrategyGenerator @@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler'] @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(NodeHandler): +class ReshapeHandler(MetaInfoNodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. 
""" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index bda160906..0362de780 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,7 +3,7 @@ from typing import Dict, List import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import NodeHandler +from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator @@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler'] @operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.nn.functional.dropout) -class UnaryElementwiseHandler(NodeHandler): +class UnaryElementwiseHandler(MetaInfoNodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. """ diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py index a765e5055..34feefb43 100644 --- a/colossalai/fx/profiler/shard_utils.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float: fwd_time (float): the result of `fwd_time` """ # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs - return n.meta["fwd_flop"] + return n.meta["fwd_time"] def calculate_bwd_time(n: Node) -> float: @@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float: bwd_time (float): the result of `bwd_time` """ # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs - return n.meta["bwd_flop"] + return n.meta["bwd_time"] diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index daf81034f..2831b10a3 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): if discard_input: alloc_numel -= input_numel + return alloc_numel, peak_numel + def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze split memory footprint split will allocate memory for the output tensor if we don't apply shard on the first dimension of @@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # kind of weird, and I think we could ignore it for now. 
pass + return alloc_numel, peak_numel + def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory """ - pass + return alloc_numel, peak_numel def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze all_to_all memory footprint @@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): if discard_input: alloc_numel -= input_numel + return alloc_numel, peak_numel + def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory """ - pass + return alloc_numel, peak_numel pattern_to_func_dict = { CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], @@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair - if idx == 0: - fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) - else: - fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, + fwd_peak_numel) if idx == 0 else fwd_action( + comm_spec, True, fwd_alloc_numel, fwd_peak_numel) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, + bwd_peak_numel) if idx == 0 else bwd_action( + comm_spec, True, bwd_alloc_numel, bwd_peak_numel) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
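A minimal usage sketch of the updated rotor solver, not part of the diff above: it assumes ``gm`` is a traced ``torch.fx.GraphModule`` whose nodes already carry the meta information produced by the metainfo passes in this change, and that ``CheckpointSolverRotor`` is re-exported from ``colossalai.auto_parallel.checkpoint`` (otherwise import it from ``ckpt_solver_rotor`` directly); the ``optim_multiplier=2.0`` value is purely illustrative (e.g. an Adam-like optimizer keeping two extra states per parameter).

    import torch
    from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor    # assumed re-export

    # the solver now subtracts parameter_size * (1 + optim_multiplier) from the budget internally
    free_memory = torch.cuda.mem_get_info(device=0)[0]

    solver = CheckpointSolverRotor(gm.graph,
                                   free_memory=free_memory,
                                   memory_slots=500,
                                   optim_multiplier=2.0)    # illustrative value
    gm.graph = solver.solve()    # or solver.solve(force_python=True) to skip the C solver
    gm.recompile()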
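A sketch of how the new metainfo passes might be chained before running a checkpoint solver. Only the call signature of ``comm_metainfo_pass`` is taken from this diff; the surrounding calls (``runtime_apply_pass``, ``MetaInfoProp(gm).run()``) and the dictionaries produced by an earlier preparation pass are assumptions for illustration.

    from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
    from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
    from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass

    # gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict are assumed to come from an
    # earlier preparation pass that solved and annotated the sharding strategies
    gm = runtime_apply_pass(gm)    # assumed to insert runtime_apply / runtime_comm_spec_apply nodes and return gm
    gm = comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)    # attach best_metainfo
    MetaInfoProp(gm).run()    # assumed entry point: fill node.meta so the checkpoint solvers can read it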