mirror of https://github.com/hpcaitech/ColossalAI

Merge pull request #2258 from hpcaitech/debug/ckpt-autoparallel
[autockpt] provide option for activation checkpoint search in SPMD solver (pull/2312/head)

commit d45695d94e
@@ -5,8 +5,12 @@ from typing import Any, List
 import torch
 from torch.fx import Graph, Node

+from colossalai.auto_parallel.passes.runtime_apply_pass import (
+    runtime_apply,
+    runtime_apply_for_iterable_object,
+    runtime_comm_spec_apply,
+)
 from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-from colossalai.fx.profiler.memory_utils import is_inplace

 __all___ = ['CheckpointSolverBase']


@@ -31,10 +35,11 @@ class CheckpointSolverBase(ABC):
         free_memory: float = -1.0,
         requires_linearize: bool = False,
         cnode: List[str] = None,
+        optim_multiplier: float = 1.0,
     ):
-        """CheckpointSolver class will integrate information provided by the components
-        and use an existing solver to find a possible optimal strategies combination for
-        target computing graph.
+        """``CheckpointSolverBase`` class will integrate information provided by the components
+        and use an existing solver to find a possible optimal strategies combination for target
+        computing graph.

         Existing Solvers:
             Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)

@@ -45,9 +50,11 @@ class CheckpointSolverBase(ABC):
             free_memory (float): Memory constraint for the solution.
             requires_linearize (bool): Whether the graph needs to be linearized.
             cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.

         Warnings:
-            `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
+            Meta information of the graph is required for any ``CheckpointSolver``.
         """
         # super-dainiu: this graph is a temporary graph which can refer to
         # the owning module, but we will return another deepcopy of it after

@@ -57,13 +64,14 @@ class CheckpointSolverBase(ABC):
         _copy_output(graph, self.graph)
         self.graph.set_codegen(ActivationCheckpointCodeGen())

-        # check if `MetaInfoProp` is done
+        # check if has meta information
         if any(len(node.meta) == 0 for node in self.graph.nodes):
             raise RuntimeError(
-                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
+                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
+            )

-        self.free_memory = free_memory
-        self.parameter_size = _get_param_size(self.graph.owning_module)
+        # parameter memory = parameter size + optimizer extra weight storage
+        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
         self.cnode = cnode
         self.requires_linearize = requires_linearize
         if self.requires_linearize:
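Not part of the diff — a tiny, self-contained restating of the budget arithmetic in the new `self.free_memory` line, with made-up numbers. The reading assumed here is that the `(optim_multiplier + 1)` factor reserves room for the parameters themselves plus `optim_multiplier` times that size of extra optimizer state, and whatever remains is the activation budget the solver may spend.

    # Illustrative only; `param_size` stands in for _get_param_size(graph.owning_module).
    def activation_budget(free_memory: float, param_size: float, optim_multiplier: float = 1.0) -> float:
        return free_memory - param_size * (optim_multiplier + 1)

    GiB = 1024 ** 3
    # 16 GiB free, 1 GiB of parameters, an optimizer keeping 2x extra state per parameter
    print(activation_budget(16 * GiB, 1 * GiB, optim_multiplier=2.0) / GiB)    # 13.0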
@@ -93,7 +101,7 @@ class CheckpointSolverBase(ABC):
         the actual 'node' in linearized manner.

         Remarks:
-            Do merge the inplace ops into the previous node.
+            Do merge the inplace ops and shape-consistency ops into the previous node.
         """

         # Common nodes are type of nodes that could be seen as attributes and remain

@@ -131,7 +139,23 @@ class CheckpointSolverBase(ABC):
                 bool
             """

-            return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
+            def _is_inplace(n: Node):
+                """Get the inplace argument from ``torch.fx.Node``
+                """
+                inplace = False
+                if n.op == "call_function":
+                    inplace = n.kwargs.get("inplace", False)
+                elif n.op == "call_module":
+                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+                return inplace
+
+            def _is_shape_consistency(n: Node):
+                """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
+                """
+                return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
+
+            return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
+                map(_is_shape_consistency, n.users))

         # make sure that item in cnode is valid
         if self.cnode:
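Not part of the diff — a standalone toy showing what the new `_is_inplace` helper looks for on a traced module. Only plain `torch.fx` is used; the `Toy` module and the printout are illustrative, not ColossalAI code.

    import torch
    import torch.nn as nn
    from torch.fx import Node, symbolic_trace

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            return torch.flatten(self.relu(x), start_dim=1)

    def is_inplace(n: Node) -> bool:
        # same decision rule as the `_is_inplace` closure added above
        if n.op == "call_function":
            return bool(n.kwargs.get("inplace", False))
        if n.op == "call_module":
            return bool(getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False))
        return False

    gm = symbolic_trace(Toy())
    for node in gm.graph.nodes:
        print(f"{node.op:15} {node.name:10} inplace={is_inplace(node)}")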
---- next file ----

@@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
     Note that this algorithm targets at memory optimization only, using techniques in appendix A.

     Usage:
-        Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+        Assume that we have a ``GraphModule``, and we have already done the extractions
         to the graph to retrieve all information needed, then we could use the following
-        code to find a solution using `CheckpointSolverChen`:
+        code to find a solution using ``CheckpointSolverChen``:
         >>> solver = CheckpointSolverChen(gm.graph)
         >>> chen_graph = solver.solve()
         >>> gm.graph = chen_graph    # set the graph to a new graph

@@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
     def grid_search(self) -> Set:
         """
         Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
-        Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
+        Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
         """
         _, b_approx = self.run_chen_greedy(0)
         b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
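Not part of the diff — the bracket arithmetic behind the `grid_search` docstring above, with a made-up first-pass estimate. `b_approx` stands in for the budget returned by `run_chen_greedy(0)`, and `num_grids` is a hypothetical grid resolution.

    import math

    b_approx = 4096    # pretend this came back from the first greedy pass (b = sqrt(xy))
    b_min = math.floor(b_approx / math.sqrt(2))
    b_max = math.ceil(b_approx * math.sqrt(2))

    num_grids = 6
    step = (b_max - b_min) / num_grids
    candidate_budgets = [b_min + i * step for i in range(num_grids + 1)]
    print(b_min, b_max)
    print(candidate_budgets)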
---- next file ----

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
 from torch import Tensor
 from torch.fx import Graph, Node

+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,

@@ -22,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']

 class CheckpointSolverRotor(CheckpointSolverBase):

-    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
+    def __init__(self,
+                 graph: Graph,
+                 free_memory: float = -1,
+                 cnode: List[str] = None,
+                 memory_slots: int = 500,
+                 optim_multiplier: float = 1.0):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.

         Usage:
-            Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+            Assume that we have a ``GraphModule``, and we have already done the extractions
             to the graph to retrieve all information needed, then we could use the following
-            code to find a solution using `CheckpointSolverRotor`:
+            code to find a solution using ``CheckpointSolverRotor``:
             >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
             >>> rotor_graph = solver.solve(force_python=True)    # otherwise use C solver
             >>> gm.graph = rotor_graph    # set the graph to a new graph

@@ -41,8 +47,10 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.
         """
-        super().__init__(graph, free_memory, True, cnode)
+        super().__init__(graph, free_memory, True, cnode, optim_multiplier)
         self.memory_slots = memory_slots

         # construct chain
@@ -128,16 +136,24 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         xbar = 0
         ftime = 0
         btime = 0
+        fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))

             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)

         x = calculate_fwd_out(node[-1])
         xbar = max(x, xbar)
-        ftmp = cls._extract_ftmp(node)
+        ftmp = fwd_mem_peak - xbar
         btmp = cls._extract_btmp(node)
         return ftime, btime, x, xbar, ftmp, btmp

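Not part of the diff — a toy walk-through of the new peak-memory bookkeeping in the hunk above, using made-up per-node numbers instead of real `node.meta` entries. Here `fwd_mem_out`, `fwd_mem_tmp` and `unused_out` play the roles of the forward output size, the temporary forward memory, and `_extract_unused_output` respectively.

    nodes = [
        {"fwd_mem_out": 4, "fwd_mem_tmp": 2, "unused_out": 0},
        {"fwd_mem_out": 8, "fwd_mem_tmp": 6, "unused_out": 1},
        {"fwd_mem_out": 2, "fwd_mem_tmp": 0, "unused_out": 0},
    ]

    xbar, fwd_mem_peak = 0, 0
    for n in nodes:
        xbar += n["fwd_mem_out"]                      # everything the block has materialized so far
        fwd_mem_peak = max(fwd_mem_peak, xbar + n["fwd_mem_tmp"] + n["unused_out"])

    x = nodes[-1]["fwd_mem_out"]                      # output actually handed to the next stage
    xbar = max(x, xbar)
    ftmp = fwd_mem_peak - xbar                        # transient overhead above what stays resident
    print(x, xbar, ftmp)                              # 2 14 5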
@@ -151,10 +167,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         return input_tensors

     @staticmethod
-    def _extract_ftmp(node: List[Node]) -> int:
-        """Extract ftmp from a list of nodes"""
-        n = node[-1]
-        return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
+    def _extract_unused_output(node: Node) -> int:
+        """Extract unused output from `torch.fx.Node`"""
+        return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)

     @staticmethod
     def _extract_btmp(node: List[Node]) -> int:

@@ -290,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             lhs (int): The left index of the interval to backtrack.
             rhs (int): The right index of the interval to backtrack.
             budget (int): The memory budget for processing this interval.
-            cost_table (List[Any]): See `._compute_table()` for definitions
-            back_ptr (List[Any]): See `._compute_table()` for definitions
+            cost_table (List[Any]): See ``._compute_table()`` for definitions
+            back_ptr (List[Any]): See ``._compute_table()`` for definitions

         Raises:
             ValueError: Can not process the chain.

@@ -332,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):

     @staticmethod
     def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
-        """Annotate the nodes in the node_list with activation checkpoint from the sequence.
+        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.

         Args:
             sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
---- next file ----

@@ -5,8 +5,11 @@ import torch.nn as nn

 from ..tensor_shard.constants import *

-# list of inplace operations
+# list of inplace module
 INPLACE_MODULE = [nn.ReLU]

+# list of inplace operations
+INPLACE_OPS = [torch.flatten]
+
 # list of operations that do not save forward activations
 NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
---- next file ----

@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """

-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]

     # calculate cost

     # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(

@@ -60,7 +59,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_in = []
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

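Not part of the diff — a minimal mock showing the operand filtering that replaces the old two-operand unpacking above. `OpType` and `OpData` are stand-ins for ColossalAI's `OperationDataType`/`OperationData`, and the integer `data` fields stand in for tensors.

    from dataclasses import dataclass
    from enum import Enum

    class OpType(Enum):
        ARG = 0
        PARAM = 1
        OUTPUT = 2

    @dataclass
    class OpData:
        name: str
        type: OpType
        data: int    # stand-in for a meta tensor

    args = [
        OpData("x", OpType.ARG, 3),
        OpData("bias", OpType.PARAM, 1),
        OpData("out", OpType.OUTPUT, 4),
    ]

    # every non-output operand is an input, however many there are
    input_op_data = [a for a in args if a.type != OpType.OUTPUT]
    output_op_data = next(a for a in args if a.type == OpType.OUTPUT)

    fwd_in_args = [a.data for a in input_op_data]
    param_args = [a.data for a in input_op_data if a.type == OpType.PARAM]
    print(fwd_in_args, output_op_data.data, param_args)    # [3, 1] 4 [1]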
---- next file ----

@@ -1,6 +1,5 @@
 from typing import Callable, List

-import numpy as np
 import torch

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (

@@ -13,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.tensor.sharding_spec import ShardingSpec

-from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
+from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register

 __all__ = ['MetaInfo']

@@ -71,25 +70,12 @@ class MetaInfo:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()

-    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
         """
-        Compute sharded meta tensor based on the given data and sharding spec.
+        Compute sharded opdata based on the given data and sharding spec.
         """
-        shard_sequnce = sharding_spec.sharding_sequence
-        device_mesh = sharding_spec.device_mesh
-        shape = operation_data.data.shape
-
-        new_shape = []
-        for dim, shard in zip(shape, shard_sequnce):
-            if shard.is_replica:
-                # replica
-                new_shape.append(dim)
-            else:
-                # sharded according to device_mesh shape
-                new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
-
         return OperationData(name=operation_data.name,
-                             data=torch.zeros(new_shape, device="meta"),
+                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                              type=operation_data.type,
                              logical_shape=operation_data.logical_shape)

@@ -113,11 +99,13 @@ class MetaInfo:
         save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION

         # construct args for meta_func
-        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
+        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

         # construct kwargs
         if self.target in INPLACE_MODULE:
             kwargs = {'inplace': self.target.inplace}
+        elif self.target in INPLACE_OPS:
+            kwargs = {'inplace': True}
         else:
             kwargs = {'inplace': False}

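Not part of the diff — a loose restating of the kwargs selection above, outside the `MetaInfo` class. How `self.target` is normalised (class vs. instance) differs in the real code, so the `type(...)` check below is only illustrative.

    import torch
    import torch.nn as nn

    INPLACE_MODULE = [nn.ReLU]
    INPLACE_OPS = [torch.flatten]

    def inplace_kwargs(target):
        if type(target) in INPLACE_MODULE:
            # modules carry their own flag, e.g. nn.ReLU(inplace=True)
            return {"inplace": target.inplace}
        if target in INPLACE_OPS:
            # ops in INPLACE_OPS are treated as producing no extra copy for costing purposes
            return {"inplace": True}
        return {"inplace": False}

    print(inplace_kwargs(nn.ReLU(inplace=True)))    # {'inplace': True}
    print(inplace_kwargs(torch.flatten))            # {'inplace': True}
    print(inplace_kwargs(torch.add))                # {'inplace': False}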
---- next file ----

@@ -0,0 +1,113 @@
+from typing import Dict
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.node import Node
+
+from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
+from colossalai.tensor.comm_spec import CommSpec
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+shape_consistency_manager = ShapeConsistencyManager()
+
+
+def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+    # get comm_action_sequence and total_cost from shape_consistency_manager
+    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
+        origin_sharding_spec, target_sharding_spec)
+
+    meta_info = MetaInfo()
+    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
+    # get mem cost for MetaInfo
+    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
+    # extract user that has _meta_data and extract element length
+    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    element_length = input_node._meta_data.element_size()
+
+    mem_cost.fwd.activation *= element_length
+    mem_cost.fwd.temp *= element_length
+    mem_cost.bwd.activation *= element_length
+    mem_cost.bwd.temp *= element_length
+    mem_cost.total.activation *= element_length
+
+    meta_info.memory_cost = mem_cost
+
+    # get computation cost for MetaInfo
+    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                            total_cost['backward'] * element_length,
+                                            total_cost['total'] * element_length)
+
+    # get tensor shape for MetaInfo
+    origin_sharding_spec: ShardingSpec
+    target_sharding_spec: ShardingSpec
+    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
+    output_shape = target_sharding_spec.get_sharded_shape_per_device()
+
+    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_buffer = []
+    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+
+    return meta_info
+
+
+def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
+    """
+    This method is used to construct `MetaInfo` for shape consistency node
+    """
+
+    # extract node index and user node index
+    args = node.args
+    node_index, user_node_index = args[3], args[4]
+    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
+        user_node_index]
+
+    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+    # extract node_index and op_data_name
+    node_index, op_data_name = node.args[2], node.args[3]
+
+    comm_action = comm_actions_dict[node_index][op_data_name]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        # this case is for all_reduce, there will be no memory cost
+        meta_info = MetaInfo()
+        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
+        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        element_length = output_node._meta_data.element_size()
+
+        total_cost = comm_action.comm_spec.get_comm_cost()
+        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                                total_cost['backward'] * element_length,
+                                                total_cost['total'] * element_length)
+
+        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
+        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_buffer = []
+        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    else:
+        # this case will be handled by shape consistency manager
+        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
+            'tgt_spec']
+        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+    return meta_info
+
+
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
+                       comm_actions_dict: Dict) -> GraphModule:
+    """
+    The method manages all the metainfo of the communication node (runtime_apply, runtime_comm_spec_apply) in the graph.
+    """
+    for node in gm.graph.nodes:
+        if node.target == runtime_apply:
+            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+        elif node.target == runtime_comm_spec_apply:
+            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+        else:
+            pass
+    return gm
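Not part of the diff — the new pass above reports communication costs counted in number of elements and rescales them into bytes with `element_size()`. A tiny standalone illustration of that conversion:

    import torch

    t = torch.empty(4, 8, dtype=torch.float16, device="meta")
    numel_cost = t.numel()                            # cost as a raw element count
    byte_cost = numel_cost * t.element_size()         # what actually lands in the memory cost
    print(numel_cost, t.element_size(), byte_cost)    # 32 2 64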
---- next file ----

@@ -0,0 +1,8 @@
+import torch
+
+OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
---- next file ----

@@ -1,17 +1,16 @@
 import uuid
 from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Tuple
+from typing import List

 import torch
 import torch.fx
 from torch.fx import GraphModule
-from torch.fx.node import Argument, Node, Target
-from torch.utils._pytree import tree_map
+from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler import MetaInfo
-from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
+from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
-from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS


 def _normalize_tuple(x):

@@ -47,7 +46,7 @@ class MetaInfoProp:
         """
         Check if the node is inplace operation.
         """
-        if node.op == 'call_method':
+        if node.op == 'call_module':
             return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
         elif node.op == "call_function":
             return node.target in OUTPUT_SAVED_OPS

@@ -68,7 +67,7 @@ class MetaInfoProp:
         """
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
-        graph_info.fwd_out = list(out)
+        graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}

     @compatibility(is_backward_compatible=False)
@@ -97,66 +96,70 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
+        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
         graph_info = GraphInfo()
         meta_info = node.best_metainfo
         meta_info: MetaInfo

         # set data_ptr for input_tensor in MetaInfo class
-        input_tensor: List[torch.Tensor] = meta_info.fwd_in
-        buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
-        output_tensor: List[torch.Tensor] = meta_info.fwd_out
+        input_tensors: List[torch.Tensor] = meta_info.fwd_in
+        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
+        output_tensors: List[torch.Tensor] = meta_info.fwd_out

-        if len(input_tensor) > 0:
+        if self._is_inplace(node):
+            # inplace operation will not create new tensor, and it only has one parent node
+            # TODO: Verify this observation
+            # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
+            parent_node = list(node._input_nodes.keys())[0]
+            parent_tensor = parent_node.meta.get("fwd_out")[0]
+            parent_tensor: torch.Tensor
+            for tensor in input_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in buffer_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in output_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+
+        else:
             for par in node._input_nodes:
-                if par.meta:
-                    if len(par.meta["fwd_out"]) > 0:
-                        # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
-                        for tensor in par.meta["fwd_out"]:
-                            tensor: torch.Tensor
-                            target_tensor = next(
-                                (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
-                            target_tensor.data_ptr = tensor.data_ptr
+                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
+                for tensor in par.meta.get("fwd_out", []):
+                    tensor: torch.Tensor
+                    target_input_tensor = next(
+                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+                    if target_input_tensor is not None:
+                        target_input_tensor.data_ptr = tensor.data_ptr

             # set data_ptr for tensor in input_tensor that is not set
-            for tensor in input_tensor:
+            for tensor in input_tensors:
                 if not tensor.data_ptr():
                     self._set_data_ptr(tensor)

-            # attach it to graph_info
-            graph_info.fwd_in = input_tensor
-
-        if self._is_inplace(node):
-            # inplace operation will not create new tensor
-            # set data_ptr for buffer_tensor and output_tensor of current node
-            for tensor in input_tensor:
-                tensor: torch.Tensor
-                target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_buffer_tensor.data_ptr = tensor.data_ptr
-                target_output_tensor.data_ptr = tensor.data_ptr
-            # attach them to graph_info
-            graph_info.fwd_tmp = buffer_tensor
-            graph_info.fwd_out = output_tensor
-
-        else:
             # set data_ptr for buffer_tensor
-            for tensor in buffer_tensor:
+            for tensor in buffer_tensors:
                 self._set_data_ptr(tensor)
-            # attach it to graph_info
-            graph_info.fwd_tmp = buffer_tensor

             # set data_ptr for output_tensor
-            for tensor in output_tensor:
+            for tensor in output_tensors:
                 self._set_data_ptr(tensor)
-            # attach it to graph_info
-            graph_info.fwd_out = output_tensor
+
+        # attach them to graph_info
+        graph_info.fwd_in = input_tensors
+        graph_info.fwd_tmp = buffer_tensors
+        graph_info.fwd_out = output_tensors

         # fetch other memory informations
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
+        graph_info.fwd_mem_out = memory_cost.fwd.activation
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
+        graph_info.bwd_mem_out = memory_cost.bwd.activation
+
+        # fetch flop information
+        # here we use fwd_time and bwd_time to deal with the case that
+        # communication cost is a float
+        compute_cost = meta_info.compute_cost
+        graph_info.fwd_time = compute_cost.fwd
+        graph_info.bwd_time = compute_cost.bwd

         node.meta = {**asdict(graph_info)}

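Not part of the diff — a toy showing why the pass above overrides `data_ptr` on meta tensors: tensors that alias the same storage end up reporting the same fake address, so a later pass that sums memory can de-duplicate them. The `fresh_ptr` helper only mimics the idea behind `_set_data_ptr`; the real method is not shown in this hunk.

    import uuid
    import torch

    def fresh_ptr():
        # hand out one fixed fake address, as a bound callable
        addr = uuid.uuid4().int
        return lambda: addr

    a = torch.empty(4, 4, device="meta")
    b = torch.empty(4, 4, device="meta")    # pretend b is produced in-place from a

    a.data_ptr = fresh_ptr()                # fresh allocation gets its own fake address
    b.data_ptr = a.data_ptr                 # alias: share the parent's address

    unique_allocations = {t.data_ptr() for t in (a, b)}
    print(len(unique_allocations))          # 1 -> counted once when summing activation memory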
---- next file ----

@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst


-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInto` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shape for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.

@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                               runtime_apply,
                                                               args=(node, origin_dict_node, input_dict_node,
                                                                     node_to_index_dict[node], user_node_index))
-            meta_info = construct_meta_info(node, user_node)
-            setattr(shape_consistency_node, 'best_metainfo', meta_info)

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
---- next file ----

@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']


 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
---- next file ----

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,

@@ -138,8 +138,7 @@ class NodeHandler(ABC):
             return None

         if self.node.op == 'call_module':
-            submod = self.node.graph.owning_module.get_submodule(self.node.target)
-            target = type(submod)
+            target = self.node.graph.owning_module.get_submodule(self.node.target)
         elif self.node.op == 'call_function':
             target = self.node.target
         elif self.node.op == 'call_method':

@@ -235,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)

-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector

@@ -282,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)

-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector
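Not part of the diff — a stripped-down registry sketch showing the guard pattern introduced above: only targets that are registered get a metainfo-based cost, everything else keeps the default cost model. The `Registry` class and keys below are illustrative, not the ColossalAI `meta_register` API.

    class Registry:
        def __init__(self):
            self._table = {}

        def register(self, key):
            def wrapper(fn):
                self._table[key] = fn
                return fn
            return wrapper

        def has(self, key):
            return key in self._table

        def get(self, key):
            return self._table[key]

    meta_register = Registry()

    @meta_register.register("add")
    def add_meta(target):
        return {"compute": 1, "memory": 2}    # made-up costs

    def annotate_cost(target):
        if meta_register.has(target):
            return meta_register.get(target)(target)    # patched target: use the meta profiler
        return {"compute": None, "memory": None}         # unpatched: fall back to the default cost model

    print(annotate_cost("add"))       # {'compute': 1, 'memory': 2}
    print(annotate_cost("matmul"))    # {'compute': None, 'memory': None}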
---- next file ----

@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch

 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ReshapeGenerator, StrategyGenerator


@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.unsqueeze)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
-class ReshapeHandler(NodeHandler):
+class ReshapeHandler(MetaInfoNodeHandler):
     """
     A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
     """
---- next file ----

@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch

 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import StrategyGenerator, UnaryElementwiseGenerator


@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
 @operator_registry.register(torch.Tensor.contiguous)
 @operator_registry.register(torch.nn.functional.dropout)
-class UnaryElementwiseHandler(NodeHandler):
+class UnaryElementwiseHandler(MetaInfoNodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
     """
---- next file ----

@@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
         fwd_time (float): the result of `fwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["fwd_flop"]
+    return n.meta["fwd_time"]


 def calculate_bwd_time(n: Node) -> float:

@@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
         bwd_time (float): the result of `bwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["bwd_flop"]
+    return n.meta["bwd_time"]
---- next file ----

@@ -441,6 +441,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel

+            return alloc_numel, peak_numel
+
         def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze split memory footprint
             split will allocate memory for the output tensor if we don't apply shard on the first dimension of

@@ -478,11 +480,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 # kind of weird, and I think we could ignore it for now.
                 pass

+            return alloc_numel, peak_numel
+
         def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel

         def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_to_all memory footprint

@@ -508,11 +512,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             if discard_input:
                 alloc_numel -= input_numel

+            return alloc_numel, peak_numel
+
         def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """
             a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
             """
-            pass
+            return alloc_numel, peak_numel

         pattern_to_func_dict = {
             CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],

@@ -539,17 +545,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
             fwd_action, comm_spec = action_spec_pair
-            if idx == 0:
-                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
-            else:
-                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
+            fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel,
+                                                         fwd_peak_numel) if idx == 0 else fwd_action(
+                                                             comm_spec, True, fwd_alloc_numel, fwd_peak_numel)

         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
         for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
             bwd_action, comm_spec = action_spec_pair
-            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
+            bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel,
+                                                         bwd_peak_numel) if idx == 0 else bwd_action(
+                                                             comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
         bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
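Not part of the diff — a toy version of the accumulation the fixed loops now perform: each analysis function returns the updated (alloc, peak) pair, which is threaded into the next call, and the final pair is split into activation versus temporary memory as in `MemoryCost(activation=..., temp=...)`. The per-action numbers below are made up and the "analysis" is deliberately simplistic.

    def fake_analysis(out_numel, discard_input, alloc, peak, in_numel):
        # allocate the output, track the high-water mark, optionally discard the input
        alloc += out_numel
        peak = max(peak, alloc)
        if discard_input:
            alloc -= in_numel
        return alloc, peak

    actions = [(10, 4), (6, 10), (3, 6)]    # (output numel, input numel) per comm action

    alloc, peak = 0, 0
    for idx, (out_numel, in_numel) in enumerate(actions):
        # the first action keeps its input, later ones may discard it
        alloc, peak = fake_analysis(out_numel, idx != 0, alloc, peak, in_numel)

    activation, temp = alloc, peak - alloc    # mirrors MemoryCost(activation=..., temp=...)
    print(alloc, peak, activation, temp)      # 3 16 3 13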