ColossalAI/colossalai/auto_parallel/passes/meta_info_prop.py

163 lines
5.8 KiB
Python
Raw Normal View History

import uuid
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
'placeholder': self.placeholder_handler,
'get_attr': self.get_attr_handler,
'output': self.output_handler,
'call_function': self.node_handler,
'call_module': self.node_handler,
'call_method': self.node_handler,
}
def _set_data_ptr(self, x):
"""
Set uuid to tensor
"""
if isinstance(x, torch.Tensor):
if not x.data_ptr():
data_ptr = uuid.uuid4()
x.data_ptr = lambda: data_ptr
def _is_inplace(self, node: Node):
"""
Check if the node is inplace operation.
"""
if node.op == 'call_method':
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
return False
def run(self) -> GraphModule:
"""
Run the meta information propagation pass on the module.
"""
for node in self.module.graph.nodes:
node: Node
self.func_dict[node.op](node)
@compatibility(is_backward_compatible=False)
def placeholder_handler(self, node: Node) -> None:
"""
Handle the placeholder node.
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out)
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def get_attr_handler(self, node: Node) -> None:
"""
Handle the get_attr node.
"""
graph_info = GraphInfo()
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def output_handler(self, node: Node) -> None:
"""
Handle the output node.
"""
graph_info = GraphInfo()
output_tensors = []
for par in node._input_nodes:
if par.meta:
output_tensors += par.meta["fwd_out"]
graph_info.fwd_in = output_tensors
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
def node_handler(self, node: Node) -> None:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
# set data_ptr for input_tensor in MetaInfo class
input_tensor: List[torch.Tensor] = meta_info.fwd_in
buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
output_tensor: List[torch.Tensor] = meta_info.fwd_out
if len(input_tensor) > 0:
for par in node._input_nodes:
if par.meta:
if len(par.meta["fwd_out"]) > 0:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta["fwd_out"]:
tensor: torch.Tensor
target_tensor = next(
(x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
target_tensor.data_ptr = tensor.data_ptr
# set data_ptr for tensor in input_tensor that is not set
for tensor in input_tensor:
if not tensor.data_ptr():
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_in = input_tensor
if self._is_inplace(node):
# inplace operation will not create new tensor
# set data_ptr for buffer_tensor and output_tensor of current node
for tensor in input_tensor:
tensor: torch.Tensor
target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_buffer_tensor.data_ptr = tensor.data_ptr
target_output_tensor.data_ptr = tensor.data_ptr
# attach them to graph_info
graph_info.fwd_tmp = buffer_tensor
graph_info.fwd_out = output_tensor
else:
# set data_ptr for buffer_tensor
for tensor in buffer_tensor:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_tmp = buffer_tensor
# set data_ptr for output_tensor
for tensor in output_tensor:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_out = output_tensor
# fetch other memory informations
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
node.meta = {**asdict(graph_info)}