mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #2258 from hpcaitech/debug/ckpt-autoparallel
[autockpt] provide option for activation checkpoint search in SPMD solver
pull/2312/head
commit
d45695d94e
@@ -5,8 +5,12 @@ from typing import Any, List
 import torch
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import (
+    runtime_apply,
+    runtime_apply_for_iterable_object,
+    runtime_comm_spec_apply,
+)
 from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-from colossalai.fx.profiler.memory_utils import is_inplace
 
 __all___ = ['CheckpointSolverBase']
 
@@ -31,10 +35,11 @@ class CheckpointSolverBase(ABC):
         free_memory: float = -1.0,
         requires_linearize: bool = False,
         cnode: List[str] = None,
+        optim_multiplier: float = 1.0,
     ):
-        """CheckpointSolver class will integrate information provided by the components
-        and use an existing solver to find a possible optimal strategies combination for
-        target computing graph.
+        """``CheckpointSolverBase`` class will integrate information provided by the components
+        and use an existing solver to find a possible optimal strategies combination for target
+        computing graph.
 
         Existing Solvers:
             Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@@ -45,9 +50,11 @@ class CheckpointSolverBase(ABC):
             free_memory (float): Memory constraint for the solution.
             requires_linearize (bool): Whether the graph needs to be linearized.
             cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.
 
         Warnings:
-            `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
+            Meta information of the graph is required for any ``CheckpointSolver``.
         """
         # super-dainiu: this graph is a temporary graph which can refer to
         # the owning module, but we will return another deepcopy of it after
@@ -57,13 +64,14 @@ class CheckpointSolverBase(ABC):
         _copy_output(graph, self.graph)
         self.graph.set_codegen(ActivationCheckpointCodeGen())
 
-        # check if `MetaInfoProp` is done
+        # check if has meta information
         if any(len(node.meta) == 0 for node in self.graph.nodes):
             raise RuntimeError(
-                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
+                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
+            )
 
-        self.free_memory = free_memory
-        self.parameter_size = _get_param_size(self.graph.owning_module)
+        # parameter memory = parameter size + optimizer extra weight storage
+        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
         self.cnode = cnode
         self.requires_linearize = requires_linearize
         if self.requires_linearize:
@@ -93,7 +101,7 @@ class CheckpointSolverBase(ABC):
            the actual 'node' in linearized manner.
 
         Remarks:
-            Do merge the inplace ops into the previous node.
+            Do merge the inplace ops and shape-consistency ops into the previous node.
         """
 
         # Common nodes are type of nodes that could be seen as attributes and remain
@@ -131,7 +139,23 @@ class CheckpointSolverBase(ABC):
                bool
            """
 
-            return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
+            def _is_inplace(n: Node):
+                """Get the inplace argument from ``torch.fx.Node``
+                """
+                inplace = False
+                if n.op == "call_function":
+                    inplace = n.kwargs.get("inplace", False)
+                elif n.op == "call_module":
+                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+                return inplace
+
+            def _is_shape_consistency(n: Node):
+                """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
+                """
+                return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
+
+            return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
+                map(_is_shape_consistency, n.users))
 
         # make sure that item in cnode is valid
         if self.cnode:
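A minimal sketch (not part of the patch) of the budget arithmetic introduced above; the helper name is hypothetical:

def usable_activation_budget(free_memory: float, parameter_size: float, optim_multiplier: float = 1.0) -> float:
    # parameter memory = parameter size + optimizer extra weight storage
    # e.g. optim_multiplier around 2.0 would model Adam's two moment buffers per parameter
    return free_memory - parameter_size * (optim_multiplier + 1)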
@@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
        Note that this algorithm targets at memory optimization only, using techniques in appendix A.
 
        Usage:
-           Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+           Assume that we have a ``GraphModule``, and we have already done the extractions
            to the graph to retrieve all information needed, then we could use the following
-           code to find a solution using `CheckpointSolverChen`:
+           code to find a solution using ``CheckpointSolverChen``:
            >>> solver = CheckpointSolverChen(gm.graph)
            >>> chen_graph = solver.solve()
            >>> gm.graph = chen_graph    # set the graph to a new graph
@@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
    def grid_search(self) -> Set:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
-       Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
+       Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
        """
        _, b_approx = self.run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
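A small sketch of the grid-search window described in the docstring above; the num_grids value here is illustrative, the real default lives in CheckpointSolverChen:

import math

def grid_points(b_approx: float, num_grids: int = 6) -> list:
    # appendix A searches checkpointing budgets between b/sqrt(2) and b*sqrt(2)
    b_min = math.floor(b_approx / math.sqrt(2))
    b_max = math.ceil(b_approx * math.sqrt(2))
    step = (b_max - b_min) / num_grids
    return [b_min + i * step for i in range(num_grids + 1)]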
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
 from torch import Tensor
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,
@@ -22,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']
 
 class CheckpointSolverRotor(CheckpointSolverBase):
 
-    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
+    def __init__(self,
+                 graph: Graph,
+                 free_memory: float = -1,
+                 cnode: List[str] = None,
+                 memory_slots: int = 500,
+                 optim_multiplier: float = 1.0):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.
 
         Usage:
-            Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+            Assume that we have a ``GraphModule``, and we have already done the extractions
             to the graph to retrieve all information needed, then we could use the following
-            code to find a solution using `CheckpointSolverRotor`:
+            code to find a solution using ``CheckpointSolverRotor``:
             >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
             >>> rotor_graph = solver.solve(force_python=True)    # otherwise use C solver
             >>> gm.graph = rotor_graph    # set the graph to a new graph
@@ -41,8 +47,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.
         """
-        super().__init__(graph, free_memory, True, cnode)
+        super().__init__(graph, free_memory, True, cnode, optim_multiplier)
         self.memory_slots = memory_slots
 
         # construct chain
@@ -128,16 +136,24 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         xbar = 0
         ftime = 0
         btime = 0
+        fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
 
             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)
 
         x = calculate_fwd_out(node[-1])
         xbar = max(x, xbar)
-        ftmp = cls._extract_ftmp(node)
+        ftmp = fwd_mem_peak - xbar
         btmp = cls._extract_btmp(node)
         return ftime, btime, x, xbar, ftmp, btmp
@@ -151,10 +167,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         return input_tensors
 
     @staticmethod
-    def _extract_ftmp(node: List[Node]) -> int:
-        """Extract ftmp from a list of nodes"""
-        n = node[-1]
-        return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+    def _extract_unused_output(node: Node) -> int:
+        """Extract unused output from `torch.fx.Node`"""
+        return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
 
     @staticmethod
     def _extract_btmp(node: List[Node]) -> int:
@@ -290,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             lhs (int): The left index of the interval to backtrack.
             rhs (int): The right index of the interval to backtrack.
             budget (int): The memory budget for processing this interval.
-            cost_table (List[Any]): See `._compute_table()` for definitions
-            back_ptr (List[Any]): See `._compute_table()` for definitions
+            cost_table (List[Any]): See ``._compute_table()`` for definitions
+            back_ptr (List[Any]): See ``._compute_table()`` for definitions
 
         Raises:
             ValueError: Can not process the chain.
@@ -332,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
 
     @staticmethod
    def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
-        """Annotate the nodes in the node_list with activation checkpoint from the sequence.
+        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
 
         Args:
             sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
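A simplified sketch (not part of the patch) of the running-peak accumulation the rotor hunk above introduces; it assumes node.meta already carries the fwd_*/bwd_* statistics, and the function name is a hypothetical stand-in:

def segment_stats(metas: list) -> tuple:
    xbar = 0            # saved activations accumulated so far
    fwd_mem_peak = 0    # forward peak including temporaries
    ftime = btime = 0.0
    for meta in metas:
        xbar += meta['fwd_mem_out']
        fwd_mem_peak = max(fwd_mem_peak, xbar + meta['fwd_mem_tmp'])
        ftime += max(meta['fwd_time'], 1.0)    # minimum cost keeps the DP well-behaved
        btime += max(meta['bwd_time'], 1.0)
    x = metas[-1]['fwd_mem_out']               # output handed to the next segment
    xbar = max(x, xbar)
    ftmp = fwd_mem_peak - xbar                 # forward temporary = peak minus what is kept
    return ftime, btime, x, xbar, ftmp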
@@ -5,8 +5,11 @@ import torch.nn as nn
 
 from ..tensor_shard.constants import *
 
-# list of inplace operations
+# list of inplace module
 INPLACE_MODULE = [nn.ReLU]
 
+# list of inplace operations
+INPLACE_OPS = [torch.flatten]
+
 # list of operations that do not save forward activations
 NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
 
     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]
 
-    # calculate cost
+    # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
@@ -60,7 +59,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_in = []
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
@@ -1,6 +1,5 @@
 from typing import Callable, List
 
-import numpy as np
 import torch
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -13,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
+from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register
 
 __all__ = ['MetaInfo']
@@ -71,25 +70,12 @@ class MetaInfo:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()
 
-    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
         """
-        Compute sharded meta tensor based on the given data and sharding spec.
+        Compute sharded opdata based on the given data and sharding spec.
         """
-        shard_sequnce = sharding_spec.sharding_sequence
-        device_mesh = sharding_spec.device_mesh
-        shape = operation_data.data.shape
-
-        new_shape = []
-        for dim, shard in zip(shape, shard_sequnce):
-            if shard.is_replica:
-                # replica
-                new_shape.append(dim)
-            else:
-                # sharded according to device_mesh shape
-                new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
-
         return OperationData(name=operation_data.name,
-                             data=torch.zeros(new_shape, device="meta"),
+                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                              type=operation_data.type,
                              logical_shape=operation_data.logical_shape)
@@ -113,11 +99,13 @@ class MetaInfo:
         save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
 
         # construct args for meta_func
-        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
+        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
 
         # construct kwargs
         if self.target in INPLACE_MODULE:
             kwargs = {'inplace': self.target.inplace}
+        elif self.target in INPLACE_OPS:
+            kwargs = {'inplace': True}
         else:
             kwargs = {'inplace': False}
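A minimal sketch (not part of the patch) of the arithmetic that ``sharding_spec.get_sharded_shape_per_device()`` stands in for in the hunk above; the inputs here (global shape, dim-to-mesh-axes map, mesh extents) are hypothetical:

from math import prod

def sharded_shape_per_device(global_shape, dim_partition, mesh_shape):
    new_shape = list(global_shape)
    for dim, mesh_axes in dim_partition.items():
        # a dim sharded over several mesh axes is divided by the product of their sizes
        new_shape[dim] //= prod(mesh_shape[axis] for axis in mesh_axes)
    return tuple(new_shape)

# e.g. a (64, 1024) tensor sharded on dim 0 over both axes of a (2, 4) mesh -> (8, 1024)
assert sharded_shape_per_device((64, 1024), {0: [0, 1]}, (2, 4)) == (8, 1024)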
@@ -0,0 +1,113 @@
+from typing import Dict
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.node import Node
+
+from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
+from colossalai.tensor.comm_spec import CommSpec
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+shape_consistency_manager = ShapeConsistencyManager()
+
+
+def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+    # get comm_action_sequence and total_cost from shape_consistency_manager
+    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
+        origin_sharding_spec, target_sharding_spec)
+
+    meta_info = MetaInfo()
+    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
+    # get mem cost for MetaInfo
+    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
+    # extract user that has _meta_data and extract element length
+    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    element_length = input_node._meta_data.element_size()
+
+    mem_cost.fwd.activation *= element_length
+    mem_cost.fwd.temp *= element_length
+    mem_cost.bwd.activation *= element_length
+    mem_cost.bwd.temp *= element_length
+    mem_cost.total.activation *= element_length
+
+    meta_info.memory_cost = mem_cost
+
+    # get computation cost for MetaInfo
+    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                            total_cost['backward'] * element_length,
+                                            total_cost['total'] * element_length)
+
+    # get tensor shape for MetaInfo
+    origin_sharding_spec: ShardingSpec
+    target_sharding_spec: ShardingSpec
+    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
+    output_shape = target_sharding_spec.get_sharded_shape_per_device()
+
+    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_buffer = []
+    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+
+    return meta_info
+
+
+def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
+    """
+    This method is used to construct ``MetaInfo`` for shape consistency node
+    """
+
+    # extract node index and user node index
+    args = node.args
+    node_index, user_node_index = args[3], args[4]
+    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
+        user_node_index]
+
+    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+    # extract node_index and op_data_name
+    node_index, op_data_name = node.args[2], node.args[3]
+
+    comm_action = comm_actions_dict[node_index][op_data_name]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        # this case is for all_reduce, there will be no memory cost
+        meta_info = MetaInfo()
+        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
+        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        element_length = output_node._meta_data.element_size()
+
+        total_cost = comm_action.comm_spec.get_comm_cost()
+        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                                total_cost['backward'] * element_length,
+                                                total_cost['total'] * element_length)
+
+        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
+        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_buffer = []
+        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    else:
+        # this case will be handled by shape consistency manager
+        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
+            'tgt_spec']
+        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+    return meta_info
+
+
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
+                       comm_actions_dict: Dict) -> GraphModule:
+    """
+    The method manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
+    """
+    for node in gm.graph.nodes:
+        if node.target == runtime_apply:
+            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+        elif node.target == runtime_comm_spec_apply:
+            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+        else:
+            pass
+    return gm
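A minimal sketch (not part of the patch) of the conversion behind the repeated ``*= element_length`` lines above: the shape consistency manager counts costs in numbers of elements, so each entry is scaled by the element size to obtain bytes. The function name is illustrative:

import torch

def numel_cost_to_bytes(cost_in_numel: float, dtype: torch.dtype = torch.float32) -> float:
    element_length = torch.empty(0, dtype=dtype).element_size()    # 4 for fp32, 2 for fp16
    return cost_in_numel * element_length

assert numel_cost_to_bytes(1024, torch.float16) == 2048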
@@ -0,0 +1,8 @@
+import torch
+
+OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
@@ -1,17 +1,16 @@
 import uuid
 from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Tuple
+from typing import List
 
 import torch
 import torch.fx
 from torch.fx import GraphModule
-from torch.fx.node import Argument, Node, Target
-from torch.utils._pytree import tree_map
+from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler import MetaInfo
-from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
+from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
-from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 
 
 def _normalize_tuple(x):
@@ -47,7 +46,7 @@ class MetaInfoProp:
        """
        Check if the node is inplace operation.
        """
-       if node.op == 'call_method':
+       if node.op == 'call_module':
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
@@ -68,7 +67,7 @@ class MetaInfoProp:
        """
        graph_info = GraphInfo()
        out = _normalize_tuple(getattr(node, '_meta_data', None))
-       graph_info.fwd_out = list(out)
+       graph_info.fwd_out = list(out) if out[0] is not None else []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
@@ -97,66 +96,70 @@ class MetaInfoProp:
        """
        Handle other kind of nodes
        """
-       assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
+       assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info = node.best_metainfo
        meta_info: MetaInfo

        # set data_ptr for input_tensor in MetaInfo class
-       input_tensor: List[torch.Tensor] = meta_info.fwd_in
-       buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
-       output_tensor: List[torch.Tensor] = meta_info.fwd_out
+       input_tensors: List[torch.Tensor] = meta_info.fwd_in
+       buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
+       output_tensors: List[torch.Tensor] = meta_info.fwd_out

-       if len(input_tensor) > 0:
+       if self._is_inplace(node):
+           # inplace operation will not create new tensor, and it only has one parent node
+           # TODO: Verify this observation
+           # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
+           parent_node = list(node._input_nodes.keys())[0]
+           parent_tensor = parent_node.meta.get("fwd_out")[0]
+           parent_tensor: torch.Tensor
+           for tensor in input_tensors:
+               tensor.data_ptr = parent_tensor.data_ptr
+           for tensor in buffer_tensors:
+               tensor.data_ptr = parent_tensor.data_ptr
+           for tensor in output_tensors:
+               tensor.data_ptr = parent_tensor.data_ptr
+
+       else:
            for par in node._input_nodes:
-               if par.meta:
-                   if len(par.meta["fwd_out"]) > 0:
-                       # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
-                       for tensor in par.meta["fwd_out"]:
-                           tensor: torch.Tensor
-                           target_tensor = next(
-                               (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
-                           target_tensor.data_ptr = tensor.data_ptr
+               # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
+               for tensor in par.meta.get("fwd_out", []):
+                   tensor: torch.Tensor
+                   target_input_tensor = next(
+                       (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+                   if target_input_tensor is not None:
+                       target_input_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for tensor in input_tensor that is not set
-           for tensor in input_tensor:
+           for tensor in input_tensors:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

-           # attach it to graph_info
-           graph_info.fwd_in = input_tensor
-
-       if self._is_inplace(node):
-           # inplace operation will not create new tensor
-           # set data_ptr for buffer_tensor and output_tensor of current node
-           for tensor in input_tensor:
-               tensor: torch.Tensor
-               target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                           None)
-               target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                           None)
-               target_buffer_tensor.data_ptr = tensor.data_ptr
-               target_output_tensor.data_ptr = tensor.data_ptr
-           # attach them to graph_info
-           graph_info.fwd_tmp = buffer_tensor
-           graph_info.fwd_out = output_tensor
-
-       else:
            # set data_ptr for buffer_tensor
-           for tensor in buffer_tensor:
+           for tensor in buffer_tensors:
                self._set_data_ptr(tensor)
-           # attach it to graph_info
-           graph_info.fwd_tmp = buffer_tensor

            # set data_ptr for output_tensor
-           for tensor in output_tensor:
+           for tensor in output_tensors:
                self._set_data_ptr(tensor)
-           # attach it to graph_info
-           graph_info.fwd_out = output_tensor

+       # attach them to graph_info
+       graph_info.fwd_in = input_tensors
+       graph_info.fwd_tmp = buffer_tensors
+       graph_info.fwd_out = output_tensors

        # fetch other memory informations
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information
        # here we use fwd_time and bwd_time to deal with the case that
        # communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd

        node.meta = {**asdict(graph_info)}
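A standalone sketch (not part of the patch) of the aliasing trick the data_ptr bookkeeping above relies on: meta tensors have no real storage, so ``data_ptr`` is shadowed with a plain attribute and aliased tensors report the same fake address. The address value here is hypothetical:

import torch

parent = torch.empty(4, 8, device='meta')
parent.data_ptr = lambda: 0x1000            # fake storage address for the meta tensor
child = torch.empty(4, 8, device='meta')
child.data_ptr = parent.data_ptr            # inplace result: pretend it shares the parent's storage

assert child.data_ptr() == parent.data_ptr()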
@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst
 
 
-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInto` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shape for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                          runtime_apply,
                                                          args=(node, origin_dict_node, input_dict_node,
                                                                node_to_index_dict[node], user_node_index))
-            meta_info = construct_meta_info(node, user_node)
-            setattr(shape_consistency_node, 'best_metainfo', meta_info)
 
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
 
 
 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -138,8 +138,7 @@ class NodeHandler(ABC):
             return None
 
         if self.node.op == 'call_module':
-            submod = self.node.graph.owning_module.get_submodule(self.node.target)
-            target = type(submod)
+            target = self.node.graph.owning_module.get_submodule(self.node.target)
         elif self.node.op == 'call_function':
             target = self.node.target
         elif self.node.op == 'call_method':
@@ -235,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
 
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
@@ -282,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
 
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
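A minimal stand-in sketch (not part of the patch) of the guard added in both handlers above: the registry is checked for either the target's class (module targets) or the target itself (function targets), and unregistered targets keep the default cost model. The Registry class here is hypothetical, not colossalai's meta_register:

class Registry:
    def __init__(self):
        self._table = {}

    def register(self, key):
        def wrapper(fn):
            self._table[key] = fn
            return fn
        return wrapper

    def has(self, key) -> bool:
        return key in self._table

meta_register = Registry()

@meta_register.register(max)            # pretend a profiled cost model exists for `max` only
def max_cost_model():
    return "profiled cost"

for target in (max, min):
    if meta_register.has(target.__class__) or meta_register.has(target):
        print(target.__name__, "-> profiled MetaInfo")
    else:
        print(target.__name__, "-> default cost model")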
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ReshapeGenerator, StrategyGenerator
 
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.unsqueeze)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
-class ReshapeHandler(NodeHandler):
+class ReshapeHandler(MetaInfoNodeHandler):
     """
     A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
     """
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import StrategyGenerator, UnaryElementwiseGenerator
 
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
 @operator_registry.register(torch.Tensor.contiguous)
 @operator_registry.register(torch.nn.functional.dropout)
-class UnaryElementwiseHandler(NodeHandler):
+class UnaryElementwiseHandler(MetaInfoNodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
     """
@@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
         fwd_time (float): the result of `fwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["fwd_flop"]
+    return n.meta["fwd_time"]
 
 
 def calculate_bwd_time(n: Node) -> float:
@@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
         bwd_time (float): the result of `bwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["bwd_flop"]
+    return n.meta["bwd_time"]
@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint
             split will allocate memory for the output tensor if we don't apply shard on the first dimension of
@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass
 
+            return alloc_numel, peak_numel
+
         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint
@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel
 
+            return alloc_numel, peak_numel
+
         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel
 
         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
 
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
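A minimal sketch (not part of the patch) of why the analysis helpers above must return (alloc_numel, peak_numel) and the caller must reassign them: Python integers are immutable, so rebinding the parameters inside a helper never reaches the caller. The helper names here are illustrative:

def broken_gather(alloc_numel: int, peak_numel: int, numel: int):
    alloc_numel += numel                      # local rebinding only, lost on return
    peak_numel = max(peak_numel, alloc_numel)

def fixed_gather(alloc_numel: int, peak_numel: int, numel: int):
    alloc_numel += numel
    peak_numel = max(peak_numel, alloc_numel)
    return alloc_numel, peak_numel            # caller: alloc, peak = fixed_gather(...)

alloc, peak = 0, 0
broken_gather(alloc, peak, 1024)
assert (alloc, peak) == (0, 0)                # the old code path silently lost the update
alloc, peak = fixed_gather(alloc, peak, 1024)
assert (alloc, peak) == (1024, 1024)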