import time
from functools import partial
from typing import Any, Callable, Dict, Tuple

import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.nn.parameter import Parameter
from torch.utils._pytree import tree_map

from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor

__all__ = ["profile_function", "profile_module", "profile_method"]

# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()

# a global identifier for inplace ops
do_not_cache = False


def normalize_tuple(x):
    """Wrap a non-tuple value into a 1-element tuple."""
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    """Return True if `x` is a floating-point tensor that can carry gradients."""
    return isinstance(x, torch.Tensor) and x.is_floating_point()


def detach_variables(x):
    """Detach a tensor from its autograd graph while preserving `requires_grad`."""
    if isinstance(x, torch.Tensor):
        requires_grad = x.requires_grad
        x = x.detach()
        x.requires_grad = requires_grad

    return x


@compatibility(is_backward_compatible=True)
def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices, by https://github.com/Cypher30

    To profile the actual forward memory, we first run target inside a torch.no_grad() context to get
    fwd_mem_out, then run target with grad enabled and take the extra memory that stays allocated
    (memory allocated minus fwd_mem_out) as fwd_mem_tmp.
    To profile the actual backward memory, we build dummy gradients for torch.autograd.backward and
    take bwd_mem_tmp as the peak memory during backward minus bwd_mem_out (which equals the size
    of args and kwargs).
    We also add time stamps to profile the real forward and backward time.

    Args:
        target (Callable): A Callable function
        args (Any): Arguments
        kwargs (Any): Arguments

    Returns:
        Tuple[Tuple[Any, ...], GraphInfo]: Output for the next node & memory cost and real forward and
        backward time.
    """

    graphinfo = GraphInfo()

    # detach input from the graph
    args = tree_map(detach_variables, args)
    kwargs = tree_map(detach_variables, kwargs)
    if isinstance(target, str):
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = getattr(self_obj, target)(*args_tail, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(self_obj) if hasattr(self_obj, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    else:
        # calculate fwd_mem_out
        mem_stamp0 = torch.cuda.memory_allocated()
        with torch.no_grad():
            out = target(*args, **kwargs)
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_out = mem_stamp1 - mem_stamp0
        del out

        # calculate fwd_mem_tmp & fwd_time
        mem_stamp0 = torch.cuda.memory_allocated()
        fwd_time0 = time.time()
        out = target(*args, **kwargs)
        fwd_time1 = time.time()
        graphinfo.fwd_time = fwd_time1 - fwd_time0
        mem_stamp1 = torch.cuda.memory_allocated()
        graphinfo.fwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.fwd_mem_out

        # calculate bwd_mem_tmp & bwd_time
        grad_tensors = tree_map(lambda x: torch.ones_like(x) if isinstance(x, torch.Tensor) else None, out)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        bwd_time0 = time.time()
        torch.autograd.backward(out, grad_tensors=grad_tensors)
        bwd_time1 = time.time()
        graphinfo.bwd_time = bwd_time1 - bwd_time0
        mem_stamp1 = torch.cuda.max_memory_allocated()

        # calculate bwd memory stats
        # NOTE: the module should add param to bwd_mem_out for bwd_mem_tmp calculation
        graphinfo.bwd_mem_out = activation_size(args) + activation_size(kwargs)
        graphinfo.bwd_mem_out += parameter_size(target.__self__) if hasattr(target, "__self__") and hasattr(target.__self__, "parameters") else 0
        graphinfo.bwd_mem_tmp = mem_stamp1 - mem_stamp0 - graphinfo.bwd_mem_out

    return tree_map(detach_variables, out), graphinfo


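# Illustrative usage sketch (hypothetical layer, shapes and device; assumes CUDA
# is available), showing how `_profile_concrete` might be invoked directly:
#
#   layer = torch.nn.Linear(1024, 1024).cuda()
#   x = torch.rand(64, 1024, device="cuda", requires_grad=True)
#   out, info = _profile_concrete(layer.forward, x)
#   print(info.fwd_time, info.bwd_time, info.fwd_mem_tmp, info.bwd_mem_tmp)

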
@compatibility(is_backward_compatible=False)
def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """
    Profile a Callable function with args and kwargs on meta devices.

    Args:
        target (Callable): A Callable function
        args (Any): Argument
        kwargs (Any): Argument

    Returns:
        out (Tuple[Any, ...]): The output of the target call.
        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
    """
    # This subgraph traces aten level ops inside one node.
    subgraph = Graph()

    # `flop_count` serves as a global dictionary to store results.
    flop_count = {
        Phase.FORWARD: 0,
        Phase.BACKWARD: 0,
    }

    # FlopTensor not only gets the FLOP statistics of a single node,
    # it also builds a full autograd graph for this node.
    # This makes sure we can analyze the memory dependencies and
    # decide which forward intermediate results should be kept until
    # backward is executed.
    # Hopefully, this attempt will provide a better estimation of memory.
    class FlopTensor(MetaTensor):
        _node: Node = None

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
            node = subgraph.create_node("call_function", func, args_node, kwargs_node)

            out = super().__torch_dispatch__(func, types, args, kwargs)

            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
            node.meta["phase"] = phase

            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
            # `Phase.FORWARD`
            if phase == Phase.FORWARD:
                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                    node.meta["phase"] = Phase.PLACEHOLDER

            # TODO(yby): specify `saved_tensors` for backward memory estimation
            node.meta["saved_tensor"] = []
            if phase == Phase.BACKWARD:
                node.meta["saved_tensor"] = normalize_tuple(out)

            def wrap(x):
                if isinstance(x, MetaTensor):
                    x = FlopTensor(x)
                    x._node = node
                return x

            out = tree_map(wrap, out)
            return out

    def wrap(x):
        if isinstance(x, torch.Tensor):
            x = FlopTensor(x)
            if is_autogradable(x):
                x.requires_grad_(True)
            x._node = subgraph.create_node(
                "placeholder",
                "placeholder",
                (subgraph._root,),
                name=subgraph._graph_namespace.create_name("input", x._tensor),
            )
            x._node.meta["phase"] = Phase.PLACEHOLDER
            x._node.meta["saved_tensor"] = []
        return x

    # Basically, we need to detach the args and kwargs from the outer graph.
    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    def pack(x):
        global cache, do_not_cache
        if isinstance(x, FlopTensor) and x._tensor.data_ptr() not in cache:
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            x._node.meta["saved_tensor"] += [tensor]
            if not do_not_cache:
                cache.add(x._tensor.data_ptr())
        return x

    def unpack(x):
        return x

    # `phase` marks the phase of autograd from the outside scope.
    phase = Phase.FORWARD
    # mark saved tensors with saved_tensors_hooks
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        if isinstance(target, str):
            # args[0] is the `self` object for this method call
            self_obj, *args_tail = args
            out = getattr(self_obj, target)(*args_tail, **kwargs)
        else:
            out = target(*args, **kwargs)

    # If the output is not a floating point `torch.Tensor`, or it does not
    # require grad, then we should not run backward for this node.
    if all(map(lambda x: is_autogradable(x) and x.requires_grad, normalize_tuple(out))):
        grad_out = [torch.zeros_like(t) for t in normalize_tuple(out)]
        phase = Phase.BACKWARD
        torch.autograd.backward(
            out,
            grad_out,
        )

    graph_info = autograd_graph_analysis(subgraph)
    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]

    def extract_tensor(x: Any):
        if isinstance(x, MetaTensor):
            tensor = x._tensor.detach()
            tensor.data_ptr = x._tensor.data_ptr
            return tensor
        if not isinstance(x, torch.finfo):
            return x

    graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))

    def unwrap(x):
        return MetaTensor(x) if isinstance(x, torch.Tensor) else x

    return tree_map(unwrap, out), graph_info


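# Illustrative usage sketch (hypothetical op and shapes): profiling on meta
# tensors allocates no real memory but still yields FLOP and memory estimates.
#
#   x = torch.rand(64, 1024, device="meta")
#   out, info = _profile_meta(torch.nn.functional.relu, x)
#   print(info.fwd_flop, info.bwd_flop, info.fwd_mem_tmp, info.bwd_mem_tmp)

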
@compatibility(is_backward_compatible=True)
def profile_function(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_function` node or a `torch.nn.functional` operator in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only the original `torch.nn.functional` operators are supported.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
        >>> output, meta_info = profile_function(func)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # accumulate the size of any `Parameter` passed in args and kwargs;
        # their gradients will be excluded from bwd_mem_out below
        param_size = 0

        def get_param_size(x):
            nonlocal param_size
            if isinstance(x, Parameter):
                param_size += activation_size(x)

        tree_map(get_param_size, args)
        tree_map(get_param_size, kwargs)

        # If this `call_function` runs in-place (indicated by the `inplace` kwarg),
        # we still run the profiling but discard some results regarding `target`
        global do_not_cache

        inplace = kwargs.get("inplace", False)
        if target in OUTPUT_SAVED_OPS:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            kwargs["inplace"] = False
        if device == "meta":
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            kwargs["inplace"] = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

        # grad for param will not be counted
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = target.__name__
    func = target
    return f


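# Illustrative usage sketch (hypothetical shapes): when a wrapped functional is
# called with `inplace=True`, profiling still runs with `inplace` temporarily
# disabled, and the backward memory estimates are zeroed afterwards.
#
#   x = torch.rand(128, 128, device="meta")
#   out, meta_info = profile_function(torch.nn.functional.relu)(x, inplace=True)
#   # meta_info.bwd_mem_tmp == 0 and meta_info.bwd_mem_out == 0

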
@compatibility(is_backward_compatible=True)
def profile_method(target: "Target", device: str = "meta") -> Callable:
    """
    Wrap a `call_method` node in order to
    record the memory cost and FLOPs of the execution.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # execute the method and return the result
        assert isinstance(target, str), f"The target of a `call_method` node should be a string, but got {target}."
        if device == "meta":
            out, meta = _profile_meta(target, *args, **kwargs)
        else:
            out, meta = _profile_concrete(target, *args, **kwargs)
        return out, meta

    return f


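# Illustrative usage sketch (hypothetical method and shapes), mirroring the
# docstring examples of `profile_function` and `profile_module`:
#
#   a = torch.rand(64, 128, device="meta")
#   b = torch.rand(128, 256, device="meta")
#   out, meta_info = profile_method("matmul")(a, b)

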
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
    """
    Wrap a `call_module` node or a `torch.nn` module in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only the original `torch.nn` modules are supported.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
        >>> output, meta_info = profile_module(mod)(input)
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # calculate parameter size
        param_size = parameter_size(module)

        # If this `call_module` runs in-place (indicated by the module's `inplace`
        # attribute), we still run the profiling but discard some results regarding `module`.
        global do_not_cache

        inplace = getattr(module, "inplace", False)
        if type(module) in OUTPUT_SAVED_MOD:
            do_not_cache = True
        if inplace:
            do_not_cache = True
            module.inplace = False
        if device == "meta":
            out, meta = _profile_meta(func, *args, **kwargs)
        else:
            out, meta = _profile_concrete(func, *args, **kwargs)
        if inplace:
            module.inplace = True
            meta.bwd_mem_tmp = 0
            meta.bwd_mem_out = 0
        do_not_cache = False

        # grad for param will not be counted
        meta.bwd_mem_out -= param_size
        return out, meta

    f.__name__ = module.__class__.__name__
    func = module.forward
    return f
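

# Illustrative usage sketch (hypothetical module and shapes): a module constructed
# with `inplace=True` is profiled with `inplace` temporarily turned off, and its
# backward memory estimates are zeroed, mirroring the in-place handling in
# `profile_function`.
#
#   mod = torch.nn.ReLU(inplace=True)
#   x = torch.rand(4, 3, 224, 224, device="meta")
#   output, meta_info = profile_module(mod)(x)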