
Merge pull request #2258 from hpcaitech/debug/ckpt-autoparallel

[autockpt] provide option for activation checkpoint search in SPMD solver
pull/2312/head
Boyuan Yao committed 2 years ago via GitHub
commit d45695d94e
Changed files:
  1. colossalai/auto_parallel/checkpoint/ckpt_solver_base.py (46)
  2. colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py (6)
  3. colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py (41)
  4. colossalai/auto_parallel/meta_profiler/constants.py (5)
  5. colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py (13)
  6. colossalai/auto_parallel/meta_profiler/metainfo.py (26)
  7. colossalai/auto_parallel/passes/comm_metainfo_pass.py (113)
  8. colossalai/auto_parallel/passes/constants.py (8)
  9. colossalai/auto_parallel/passes/meta_info_prop.py (95)
  10. colossalai/auto_parallel/passes/runtime_apply_pass.py (49)
  11. colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py (2)
  12. colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (49)
  13. colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (4)
  14. colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py (4)
  15. colossalai/fx/profiler/shard_utils.py (4)
  16. colossalai/tensor/shape_consistency.py (21)

46
colossalai/auto_parallel/checkpoint/ckpt_solver_base.py

@ -5,8 +5,12 @@ from typing import Any, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace
__all__ = ['CheckpointSolverBase']
@ -31,10 +35,11 @@ class CheckpointSolverBase(ABC):
free_memory: float = -1.0,
requires_linearize: bool = False,
cnode: List[str] = None,
optim_multiplier: float = 1.0,
):
"""CheckpointSolver class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for
target computing graph.
"""``CheckpointSolverBase`` class will integrate information provided by the components
and use an existing solver to find a possible optimal strategies combination for target
computing graph.
Existing Solvers:
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@ -45,9 +50,11 @@ class CheckpointSolverBase(ABC):
free_memory (float): Memory constraint for the solution.
requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Defaults to 1.0.
Warnings:
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
Meta information of the graph is required for any ``CheckpointSolver``.
"""
# super-dainiu: this graph is a temporary graph which can refer to
# the owning module, but we will return another deepcopy of it after
@ -57,13 +64,14 @@ class CheckpointSolverBase(ABC):
_copy_output(graph, self.graph)
self.graph.set_codegen(ActivationCheckpointCodeGen())
# check if `MetaInfoProp` is done
# check if has meta information
if any(len(node.meta) == 0 for node in self.graph.nodes):
raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
)
self.free_memory = free_memory
self.parameter_size = _get_param_size(self.graph.owning_module)
# parameter memory = parameter size + optimizer extra weight storage
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
self.cnode = cnode
self.requires_linearize = requires_linearize
if self.requires_linearize:
@ -93,7 +101,7 @@ class CheckpointSolverBase(ABC):
the actual 'node' in linearized manner.
Remarks:
Do merge the inplace ops into the previous node.
Do merge the inplace ops and shape-consistency ops into the previous node.
"""
# Common nodes are type of nodes that could be seen as attributes and remain
@ -131,7 +139,23 @@ class CheckpointSolverBase(ABC):
bool
"""
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
# make sure that item in cnode is valid
if self.cnode:
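For reference, a small standalone sketch of the new memory budget computation; the 16 GiB free memory, 1 GiB parameter size and the Adam-style optim_multiplier=2.0 below are illustrative assumptions, not values taken from this PR:

    # budget = free_memory - param_size * (optim_multiplier + 1),
    # i.e. parameters plus the extra optimizer weight storage are reserved up front
    GiB = 1024 ** 3
    free_memory = 16 * GiB        # e.g. torch.cuda.mem_get_info(device=0)[0]
    param_size = 1 * GiB          # bytes occupied by the model parameters (assumed)
    optim_multiplier = 2.0        # e.g. two extra states per parameter, Adam-style (assumed)
    budget = free_memory - param_size * (optim_multiplier + 1)
    print(budget / GiB)           # 13.0 GiB left for activations and checkpoints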

6
colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py

@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverChen`:
code to find a solution using ``CheckpointSolverChen``:
>>> solver = CheckpointSolverChen(gm.graph)
>>> chen_graph = solver.solve()
>>> gm.graph = chen_graph # set the graph to a new graph
@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
def grid_search(self) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
"""
_, b_approx = self.run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))

41
colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
@ -22,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
def __init__(self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code is adapted from
https://gitlab.inria.fr/hiepacs/rotor.
Usage:
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
Assume that we have a ``GraphModule``, and we have already done the extractions
to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`:
code to find a solution using ``CheckpointSolverRotor``:
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph
@ -41,8 +47,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
optim_multiplier (float, optional): The multiplier of extra weight storage for the
``torch.optim.Optimizer``. Defaults to 1.0.
"""
super().__init__(graph, free_memory, True, cnode)
super().__init__(graph, free_memory, True, cnode, optim_multiplier)
self.memory_slots = memory_slots
# construct chain
@ -128,16 +136,24 @@ class CheckpointSolverRotor(CheckpointSolverBase):
xbar = 0
ftime = 0
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)
x = calculate_fwd_out(node[-1])
xbar = max(x, xbar)
ftmp = cls._extract_ftmp(node)
ftmp = fwd_mem_peak - xbar
btmp = cls._extract_btmp(node)
return ftime, btime, x, xbar, ftmp, btmp
@ -151,10 +167,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return input_tensors
@staticmethod
def _extract_ftmp(node: List[Node]) -> int:
"""Extract ftmp from a list of nodes"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
@ -290,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
lhs (int): The left index of the interval to backtrack.
rhs (int): The right index of the interval to backtrack.
budget (int): The memory budget for processing this interval.
cost_table (List[Any]): See `._compute_table()` for definitions
back_ptr (List[Any]): See `._compute_table()` for definitions
cost_table (List[Any]): See ``._compute_table()`` for definitions
back_ptr (List[Any]): See ``._compute_table()`` for definitions
Raises:
ValueError: Can not process the chain.
@ -332,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
@staticmethod
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
Args:
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
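A hedged end-to-end usage sketch of the updated constructor, following the docstring above; ``gm`` is assumed to be a traced ``GraphModule`` whose node meta information has already been populated, and ``optim_multiplier=2.0`` is an assumed Adam-style setting rather than a value prescribed by this PR:

    import torch
    from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor

    # gm: torch.fx.GraphModule prepared earlier (see the usage snippet in the docstring)
    solver = CheckpointSolverRotor(gm.graph,
                                   free_memory=torch.cuda.mem_get_info(device=0)[0],
                                   memory_slots=500,
                                   optim_multiplier=2.0)
    gm.graph = solver.solve()
    gm.recompile()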

5
colossalai/auto_parallel/meta_profiler/constants.py

@ -5,8 +5,11 @@ import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace operations
# list of inplace modules
INPLACE_MODULE = [nn.ReLU]
# list of inplace operations
INPLACE_OPS = [torch.flatten]
# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]

13
colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py

@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
# construct forward args for flop mapping
fwd_in_args = [input_op_data.data, other_op_data.data]
fwd_in_args = [opdata.data for opdata in input_op_data]
fwd_out_args = [output_op_data.data]
# calculate cost
# calculate compute cost
# NOTE: we set bwd_compute_cost to twice fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
param_mem_cost = activation_size(
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size([input_op_data.data, output_op_data.data]),
activation=activation_size(output_op_data.data),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
@ -60,7 +59,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
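For intuition on the memory change: the forward activation now counts only the output, since inputs of these ops are not saved (cf. ``NO_SAVE_ACTIVATION``). A rough back-of-the-envelope sketch, assuming ``activation_size`` is numel × element size, for fp32 tensors of shape (1024, 1024):

    out_numel = 1024 * 1024
    elem_size = 4                                    # fp32
    fwd_activation_new = out_numel * elem_size       # only the output is counted: 4 MiB
    fwd_activation_old = 2 * out_numel * elem_size   # the old code also counted one input: 8 MiB
    print(fwd_activation_new, fwd_activation_old)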

26
colossalai/auto_parallel/meta_profiler/metainfo.py

@ -1,6 +1,5 @@
from typing import Callable, List
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@ -13,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
@ -71,25 +70,12 @@ class MetaInfo:
if self._strategy is not None and self._target is not None:
self.compute_metainfo()
def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Compute sharded meta tensor based on the given data and sharding spec.
Compute sharded opdata based on the given data and sharding spec.
"""
shard_sequnce = sharding_spec.sharding_sequence
device_mesh = sharding_spec.device_mesh
shape = operation_data.data.shape
new_shape = []
for dim, shard in zip(shape, shard_sequnce):
if shard.is_replica:
# replica
new_shape.append(dim)
else:
# sharded according to device_mesh shape
new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
return OperationData(name=operation_data.name,
data=torch.zeros(new_shape, device="meta"),
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
@ -113,11 +99,13 @@ class MetaInfo:
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}

113
colossalai/auto_parallel/passes/comm_metainfo_pass.py

@ -0,0 +1,113 @@
from typing import Dict
import torch
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
target_sharding_spec: ShardingSpec) -> MetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
meta_info = MetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is counted in numel (number of elements)
# get mem cost for MetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
mem_cost.fwd.temp *= element_length
mem_cost.bwd.activation *= element_length
mem_cost.bwd.temp *= element_length
mem_cost.total.activation *= element_length
meta_info.memory_cost = mem_cost
# get computation cost for MetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
# get tensor shape for MetaInfo
origin_sharding_spec: ShardingSpec
target_sharding_spec: ShardingSpec
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
return meta_info
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
"""
This method is used to construct `MetaInfo` for a shape-consistency node
"""
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
user_node_index]
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
# extract node_index and op_data_name
node_index, op_data_name = node.args[2], node.args[3]
comm_action = comm_actions_dict[node_index][op_data_name]
if isinstance(comm_action.comm_spec, CommSpec):
# this case is for all_reduce, there will be no memory cost
meta_info = MetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
else:
# this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
'tgt_spec']
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict) -> GraphModule:
"""
This method manages the metainfo of all the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
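A minimal invocation sketch; the pass is applied to the transformed module together with the dictionaries produced by the runtime passes (the three dict variables are assumed to come from the surrounding SPMD solver pipeline):

    from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass

    # sharding_spec_dict / origin_spec_dict / comm_actions_dict: mappings built for this
    # GraphModule by the runtime apply passes; gm is the solved GraphModule
    gm = comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)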

8
colossalai/auto_parallel/passes/constants.py

@ -0,0 +1,8 @@
import torch
OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]

95
colossalai/auto_parallel/passes/meta_info_prop.py

@ -1,17 +1,16 @@
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
def _normalize_tuple(x):
@ -47,7 +46,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
if node.op == 'call_method':
if node.op == 'call_module':
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@ -68,7 +67,7 @@ class MetaInfoProp:
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out)
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
@ -97,66 +96,70 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
# set data_ptr for input_tensor in MetaInfo class
input_tensor: List[torch.Tensor] = meta_info.fwd_in
buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
output_tensor: List[torch.Tensor] = meta_info.fwd_out
input_tensors: List[torch.Tensor] = meta_info.fwd_in
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
output_tensors: List[torch.Tensor] = meta_info.fwd_out
if len(input_tensor) > 0:
if self._is_inplace(node):
# inplace operation will not create new tensor, and it only has one parent node
# TODO: Verify this observation
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
parent_node = list(node._input_nodes.keys())[0]
parent_tensor = parent_node.meta.get("fwd_out")[0]
parent_tensor: torch.Tensor
for tensor in input_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in buffer_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in output_tensors:
tensor.data_ptr = parent_tensor.data_ptr
else:
for par in node._input_nodes:
if par.meta:
if len(par.meta["fwd_out"]) > 0:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta["fwd_out"]:
tensor: torch.Tensor
target_tensor = next(
(x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
target_tensor.data_ptr = tensor.data_ptr
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
# set data_ptr for tensor in input_tensor that is not set
for tensor in input_tensor:
for tensor in input_tensors:
if not tensor.data_ptr():
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_in = input_tensor
if self._is_inplace(node):
# inplace operation will not create new tensor
# set data_ptr for buffer_tensor and output_tensor of current node
for tensor in input_tensor:
tensor: torch.Tensor
target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_buffer_tensor.data_ptr = tensor.data_ptr
target_output_tensor.data_ptr = tensor.data_ptr
# attach them to graph_info
graph_info.fwd_tmp = buffer_tensor
graph_info.fwd_out = output_tensor
else:
# set data_ptr for buffer_tensor
for tensor in buffer_tensor:
for tensor in buffer_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_tmp = buffer_tensor
# set data_ptr for output_tensor
for tensor in output_tensor:
for tensor in output_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_out = output_tensor
# attach them to graph_info
graph_info.fwd_in = input_tensors
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors
# fetch other memory informations
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
graph_info.bwd_mem_out = memory_cost.bwd.activation
# fetch flop information
# here we use fwd_time and bwd_time to deal with the case that
# communication cost is a float
compute_cost = meta_info.compute_cost
graph_info.fwd_time = compute_cost.fwd
graph_info.bwd_time = compute_cost.bwd
node.meta = {**asdict(graph_info)}
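The data_ptr bookkeeping above works because the pass treats a zero (unset) data pointer as "no storage assigned yet" and overrides the tensor's ``data_ptr`` with a bound callable to mark aliasing. A standalone sketch of that trick (illustrative only, not the library code):

    import uuid
    import torch

    parent_out = torch.zeros(4, 4, device='meta')
    child_in = torch.zeros(4, 4, device='meta')

    # give the parent's output a fake unique pointer, then alias the child's input to it
    fake_ptr = uuid.uuid4().int
    parent_out.data_ptr = lambda: fake_ptr
    child_in.data_ptr = parent_out.data_ptr

    assert child_in.data_ptr() == parent_out.data_ptr()   # the two tensors now share "storage"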

49
colossalai/auto_parallel/passes/runtime_apply_pass.py

@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
return rst
def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
"""
This method is used to construct `MetaInto` for shape consistency node
TODO: Actually we could attain the cost information from resharding cost in node
handler, we should modify this part in the future.
"""
def compute_shape(sharding_spec: ShardingSpec):
shape = sharding_spec.entire_shape
new_shape = []
for dim, shard in sharding_spec.dim_partition_dict.items():
new_shape.append(shape[dim] // len(shard))
return new_shape
meta_info = MetaInfo()
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
str(node.name))
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for MetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
element_length = node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
mem_cost.fwd.temp *= element_length
mem_cost.bwd.activation *= element_length
mem_cost.bwd.temp *= element_length
mem_cost.total.activation *= element_length
meta_info.memory_cost = mem_cost
# get computation cost for MetaInfo
compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
meta_info.compute_cost = compute_cost
# get tensor shape for MetaInfo
input_shape = compute_shape(origin_sharding_spec)
output_shape = compute_shape(target_sharding_spec)
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
return meta_info
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
meta_info = construct_meta_info(node, user_node)
setattr(shape_consistency_node, 'best_metainfo', meta_info)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)

2
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py

@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(NodeHandler):
class BinaryElementwiseHandler(MetaInfoNodeHandler):
"""
A BinaryBcastOpHandler is a node handler which deals with operations that have two
operands and broadcasting occurs such as torch.add.

49
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -138,8 +138,7 @@ class NodeHandler(ABC):
return None
if self.node.op == 'call_module':
submod = self.node.graph.owning_module.get_submodule(self.node.target)
target = type(submod)
target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function':
target = self.node.target
elif self.node.op == 'call_method':
@ -235,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
return self.strategies_vector
@ -282,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
"""
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
target = self.get_target_function()
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
# Currently we haven't patched all the torch functions and modules, so if the target
# is not patched, we will use the default cost model to compute the cost.
# TODO: patch all torch functions and modules to make it clean
if meta_register.has(target.__class__) or meta_register.has(target):
metainfo_vector = []
for strategy in self.strategies_vector:
metainfo = MetaInfo(strategy, target)
strategy.compute_cost = metainfo.compute_cost
strategy.memory_cost = metainfo.memory_cost
metainfo_vector.append(metainfo)
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
return self.strategies_vector

4
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""

4
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler):
class UnaryElementwiseHandler(MetaInfoNodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""

4
colossalai/fx/profiler/shard_utils.py

@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
fwd_time (float): the result of `fwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["fwd_flop"]
return n.meta["fwd_time"]
def calculate_bwd_time(n: Node) -> float:
@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
bwd_time (float): the result of `bwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["bwd_flop"]
return n.meta["bwd_time"]

21
colossalai/tensor/shape_consistency.py

@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze split memory footprint
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# kind of weird, and I think we could ignore it for now.
pass
return alloc_numel, peak_numel
def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
"""
pass
return alloc_numel, peak_numel
def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_to_all memory footprint
@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
if discard_input:
alloc_numel -= input_numel
return alloc_numel, peak_numel
def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""
a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
"""
pass
return alloc_numel, peak_numel
pattern_to_func_dict = {
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
if idx == 0:
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
else:
fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
fwd_peak_numel) if idx == 0 else fwd_action(
comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
bwd_action, comm_spec = action_spec_pair
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
bwd_peak_numel) if idx == 0 else bwd_action(
comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
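The rewritten loops above fix the accounting: the ``*_analysis`` helpers return the updated counters, and because Python integers are immutable the caller must rebind them instead of discarding the return value. A standalone illustration with a dummy helper (not the library API):

    def dummy_analysis(alloc_numel: int, peak_numel: int):
        alloc_numel += 100                            # pretend this comm action allocates 100 elements
        peak_numel = max(peak_numel, alloc_numel)
        return alloc_numel, peak_numel

    alloc, peak = 0, 0
    dummy_analysis(alloc, peak)                       # old pattern: result dropped, counters stay 0
    assert (alloc, peak) == (0, 0)
    alloc, peak = dummy_analysis(alloc, peak)         # fixed pattern: counters actually advance
    assert (alloc, peak) == (100, 100)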
