|
|
|
from dataclasses import dataclass, field
|
|
|
|
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch.autograd.profiler_util import _format_memory, _format_time
|
|
|
|
from torch.fx import Graph, GraphModule, Node
|
|
|
|
|
|
|
|
from colossalai._analyzer.envs import MeshConfig
|
|
|
|
|
|
|
|
|
|
|
|
def intersect(a, b):
    """Return the entries of ``a`` whose keys are also present in ``b``.

    Values are always taken from ``a``; ``b`` is only used for membership
    tests. The result preserves the iteration order of ``a``.
    """
    common = {}
    for key in a:
        if key in b:
            common[key] = a[key]
    return common
|
|
|
|
|
|
|
|
|
|
|
|
def subtract(a, b):
    """Return the entries of ``a`` whose keys do not appear in ``b``.

    ``b`` is only used for membership tests. The result preserves the
    iteration order of ``a``.
    """
    remaining = {}
    for key in a:
        if key in b:
            # key is shadowed by ``b``; drop it
            continue
        remaining[key] = a[key]
    return remaining
|
|
|
|
|
|
|
|
|
|
|
|
def union(a, b):
    """Return a new dict holding the entries of both ``a`` and ``b``.

    When a key occurs in both inputs, the value from ``b`` wins. Neither
    input is modified.
    """
    merged = dict(a)
    merged.update(b)
    return merged
|
|
|
|
|
|
|
|
|
|
|
|
def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Compute the size of a tensor or a collection of tensors in bytes.

    Args:
        elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure.
            Dict values and the items of tuples, lists and sets are visited
            recursively. Anything that is neither a tensor nor one of those
            containers contributes 0 bytes.

    Returns:
        int: The size of the tensor or the collection of tensors in bytes.
    """
    if isinstance(elem, torch.Tensor):
        # The per-element size is read straight from the tensor, so no
        # throwaway empty tensor has to be allocated just to query the dtype.
        return elem.numel() * elem.element_size()
    if isinstance(elem, dict):
        # only the values can hold tensors; keys are ignored
        return sum(compute_size_in_bytes(value) for value in elem.values())
    if isinstance(elem, (tuple, list, set)):
        return sum(compute_size_in_bytes(item) for item in elem)
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MetaInfo:
    r"""
    The base class to store all profiling and static graph analysis information
    needed for auto-parallel system in Colossal-AI.
    ============================================================================
                            -------------------------------
                            |          FX.Node            | <-----
    [input/param] are --->  |[input/param]      [grad_inp]|   [grad_inp] contributes to the
    placeholders (might be  |     | \__________     |     |   profiled peak memory in backward
    saved for backward).    |     |            \    |     |   pass. [grad_param] is calculated
                            |     |             \   |     |   separately.
                            | [interm] -------> [grad_int]| <-----
                            |     |  \_________     |     |   [grad_interm] marks the peak
                            |    / \           \    |     |   memory in backward pass.
    [x] is not counted ---> | [x]  [interm] --> [grad_int]| <-----
    in [interm] because     |  |       \_____       |     |
    it is not saved for     |  |             \      |     |
    backward.               | [output]        \     |     | <----- [output] is potentially
                            -------------------------------   [input] for the next node.
    ============================================================================

    Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size}
    Output Size = ([output] in global_ctx and not is_alias)
    Temp Size = ([output] not in global_ctx and not is_alias)
    Backward Size = ([grad_inp])

    Usage:
        >>> for node in graph.nodes:
        >>>     n_info = MetaInfo(node)     # will create a new MetaInfo instance and store in node.meta['info']
        >>>                                 # if not exist, otherwise return the existing one
        >>>     n_info.to_recompute = ...   # set the to_recompute attribute

    Remarks:
        This feature is experimental and all the entries are subject to change.
    """

    # reference
    node: Node

    # directory
    mod_dir: str = ''

    # ctx[data_ptr] = Tensor
    # mark the storage for ctx.save_for_backward
    global_ctx: Dict[int, torch.Tensor] = field(default_factory=dict)    # globally shared
    curr_ctx: Dict[int, torch.Tensor] = field(default_factory=dict)    # global_ctx till this node

    # should be updated after each graph manipulation
    # ============================== Update ====================================
    # parameter and buffer within ``Node``
    parameters: Dict[str, torch.nn.Parameter] = field(default_factory=dict)
    buffers: Dict[str, torch.Tensor] = field(default_factory=dict)

    inputs: Tuple[torch.Tensor, ...] = ()
    outputs: Tuple[torch.Tensor, ...] = ()
    is_alias: Tuple[bool, ...] = ()    # whether the output is an alias of input

    # compute cost
    fwd_flop: Optional[int] = 0
    bwd_flop: Optional[int] = 0

    # communication cost (should be the size in bytes of communication)
    fwd_comm: Optional[int] = 0
    bwd_comm: Optional[int] = 0

    # should keep the same whenever manipulated
    # ============================= Invariant ==================================
    activation_checkpoint: Tuple[torch.Tensor, ...] = ()    # (region_0, region_1, ...) support nested codegen
    to_offload: Optional[bool] = False
    sharding_spec: str = 'RR'

    def __new__(cls, node: Node, **kwargs):
        """Return the ``MetaInfo`` already stored in ``node.meta['info']``, or allocate a new one.

        When an instance already exists on the node, that instance is returned
        and the keyword arguments of this call are not applied to it.
        """
        orig_init = cls.__init__

        # if initialized, return the existing one
        # should disable the __init__ function
        if node.meta.get('info', None) is not None:

            def _dummy(self, *args, **kwargs):
                # Stand-in for ``__init__`` for the single call that follows
                # ``__new__`` returning an existing instance.
                # ``_is_init`` is not set anywhere in this class, so as written
                # this branch is skipped and the existing fields are untouched.
                if getattr(self, '_is_init', False):
                    self._is_init = True
                    orig_init(self, *args, **kwargs)
                # restore the real initializer for subsequent constructions
                cls.__init__ = orig_init

            cls.__init__ = _dummy
            return node.meta['info']
        return super().__new__(cls)

    def __post_init__(self):
        """Register this instance on its node so later ``MetaInfo(node)`` calls find it."""
        self.node.meta['info'] = self

    @property
    def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
        """Forward cost: ``fwd_flop / tflops + fwd_comm / bandwidth``.

        Accessed as a property, so the ``MeshConfig`` defaults are always used.
        """
        return self.fwd_flop / tflops + self.fwd_comm / bandwidth

    @property
    def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
        """Backward cost: ``bwd_flop / tflops + bwd_comm / bandwidth``.

        Accessed as a property, so the ``MeshConfig`` defaults are always used.
        """
        return self.bwd_flop / tflops + self.bwd_comm / bandwidth

    @property
    def param_size(self) -> int:
        """Size in bytes of all tensors in ``parameters``."""
        return compute_size_in_bytes(self.parameters)

    @property
    def buffer_size(self) -> int:
        """Size in bytes of all tensors in ``buffers``."""
        return compute_size_in_bytes(self.buffers)

    def _output_ctx(self) -> Dict[int, torch.Tensor]:
        """Map ``data_ptr()`` -> tensor for outputs that are non-alias, non-parameter tensors."""
        return {
            o.data_ptr(): o
            for o, is_alias in zip(self.outputs, self.is_alias)
            if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
        }

    @property
    def output_size(self) -> int:
        """Used in CheckpointSolver"""
        # outputs whose storage is also recorded in ``global_ctx``
        return compute_size_in_bytes(intersect(self.global_ctx, self._output_ctx()))

    @property
    def accumulate_size(self) -> int:
        """Used in CheckpointSolver"""
        # everything in ``curr_ctx`` plus this node's outputs found in ``global_ctx``
        return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, self._output_ctx())))

    @property
    def temp_size(self) -> int:
        """Used in CheckpointSolver"""
        # outputs whose storage is NOT recorded in ``global_ctx``
        return compute_size_in_bytes(subtract(self._output_ctx(), self.global_ctx))

    @property
    def backward_size(self) -> int:
        """Used in CheckpointSolver"""
        return compute_size_in_bytes(self.inputs)

    def __repr__(self):
        """Multi-line summary of the node: non-zero sizes first, then costs and flags."""
        lines = [f'Node {self.node.name}']
        if self.parameters:
            lines.append(f'\thas parameter of size {_format_memory(self.param_size)}')
        if self.buffers:
            lines.append(f'\thas buffer of size {_format_memory(self.buffer_size)}')

        # each size property walks the tensors, so evaluate it only once
        output_size = self.output_size
        if output_size:
            lines.append(f'\thas output activation of size {_format_memory(output_size)}')
        temp_size = self.temp_size
        if temp_size:
            lines.append(f'\thas temp activation of size {_format_memory(temp_size)}')
        backward_size = self.backward_size
        if backward_size:
            lines.append(f'\thas backward activation of size {_format_memory(backward_size)}')

        # ``to_recompute`` is not a declared field; it only exists once a caller
        # has assigned it, so fall back to None rather than raising.
        to_recompute = getattr(self, 'to_recompute', None)
        lines.extend([
            f'\tfwd_flop = {self.fwd_flop}',
            f'\tbwd_flop = {self.bwd_flop}',
            f'\tfwd_comm = {self.fwd_comm}',
            f'\tbwd_comm = {self.bwd_comm}',
            f'\tto_recompute = {to_recompute}',
            f'\tto_offload = {self.to_offload}',
            f'\tsharding_spec = {self.sharding_spec}',
        ])
        return '\n'.join(lines)
|