You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/auto_parallel/passes/meta_info_prop.py

166 lines
5.9 KiB

import uuid
from dataclasses import asdict
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
    """Attach ``GraphInfo`` meta information to every node of an fx graph.

    ``run`` walks ``module.graph.nodes`` and dispatches each node, by its
    ``op``, to a handler that overwrites ``node.meta`` with the fields of a
    freshly populated ``GraphInfo``.

    Args:
        module: The ``GraphModule`` whose nodes will be annotated. Nodes with
            op ``call_function``/``call_module``/``call_method`` must carry a
            ``best_strategy_info`` attribute before ``run`` is called.
    """

    def __init__(self, module: GraphModule) -> None:
        self.module = module
        # Dispatch table: fx node opcode -> handler method.
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "get_attr": self.get_attr_handler,
            "output": self.output_handler,
            "call_function": self.node_handler,
            "call_module": self.node_handler,
            "call_method": self.node_handler,
        }

    def _set_data_ptr(self, x):
        """
        Set uuid to tensor

        If ``x`` is a ``torch.Tensor`` whose ``data_ptr()`` is falsy, replace
        its ``data_ptr`` with a callable returning a fresh ``uuid.uuid4()``.
        Non-tensors and tensors that already report a truthy pointer are left
        untouched.
        """
        if isinstance(x, torch.Tensor) and not x.data_ptr():
            # Bind the uuid to a local so every tensor gets its own stable id.
            data_ptr = uuid.uuid4()
            x.data_ptr = lambda: data_ptr

    def _is_inplace(self, node: Node):
        """
        Check if the node is inplace operation.

        A ``call_module`` node qualifies when the class of its target submodule
        is in ``OUTPUT_SAVED_MOD``; a ``call_function`` node qualifies when its
        target is in ``OUTPUT_SAVED_OPS``. Every other op returns ``False``.
        """
        if node.op == "call_module":
            submodule = node.graph.owning_module.get_submodule(node.target)
            return submodule.__class__ in OUTPUT_SAVED_MOD
        if node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
        return False

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module.

        Returns:
            The same ``GraphModule`` that was passed to the constructor, with
            ``node.meta`` rewritten on every node.
        """
        for node in self.module.graph.nodes:
            node: Node
            self.func_dict[node.op](node)
        # The signature promises a GraphModule; hand back the annotated module.
        return self.module

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node.

        ``fwd_out`` is taken from the node's ``_meta_data`` attribute
        (normalized to a tuple); it is left empty when that attribute is
        missing, ``None`` or an empty tuple.
        """
        graph_info = GraphInfo()
        out = _normalize_tuple(getattr(node, "_meta_data", None))
        # Guard against an empty tuple before peeking at the first element.
        graph_info.fwd_out = list(out) if out and out[0] is not None else []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node.

        The node receives the fields of a default-constructed ``GraphInfo``.
        """
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node.

        ``fwd_in`` is the concatenation of ``fwd_out`` of every input node
        that has non-empty ``meta``.
        """
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                # Tolerate parents whose meta has no "fwd_out" entry, the same
                # way node_handler does.
                output_tensors += par.meta.get("fwd_out", [])
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle other kind of nodes

        Reads the node's ``best_strategy_info``, assigns ``data_ptr``
        identities to its forward input/buffer/output tensors, and copies the
        tensors plus memory and compute costs into ``node.meta``.

        Raises:
            AssertionError: if ``node`` has no ``best_strategy_info``.
        """
        assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info: ShardMetaInfo = node.best_strategy_info

        # set data_ptr for input_tensor in ShardMetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out

        if self._is_inplace(node):
            # inplace operation will not create new tensor, and it only has one parent node
            # TODO: Verify this observation
            # Every tensor of this node shares the data_ptr of the parent's
            # first forward output.
            parent_node = list(node._input_nodes)[0]
            parent_tensor: torch.Tensor = parent_node.meta.get("fwd_out")[0]
            for tensor in (*input_tensors, *buffer_tensors, *output_tensors):
                tensor.data_ptr = parent_tensor.data_ptr
        else:
            for par in node._input_nodes:
                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
                for tensor in par.meta.get("fwd_out", []):
                    tensor: torch.Tensor
                    # First input that is still unassigned and has a matching shape.
                    target_input_tensor = next(
                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
                    )
                    if target_input_tensor is not None:
                        target_input_tensor.data_ptr = tensor.data_ptr

            # Remaining inputs, then buffers, then outputs get a fresh id;
            # _set_data_ptr itself skips tensors that already have one.
            for tensor in (*input_tensors, *buffer_tensors, *output_tensors):
                self._set_data_ptr(tensor)

        # attach them to graph_info
        graph_info.fwd_in = input_tensors
        graph_info.fwd_tmp = buffer_tensors
        graph_info.fwd_out = output_tensors

        # fetch other memory information
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information
        # here we use fwd_time and bwd_time to deal with the case that
        # communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd
        node.meta = {**asdict(graph_info)}