import uuid
from dataclasses import asdict
from typing import List

import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo


def _normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


@compatibility(is_backward_compatible=False)
class MetaInfoProp:

    def __init__(self, module: GraphModule) -> None:
        self.module = module
        self.func_dict = {
            'placeholder': self.placeholder_handler,
            'get_attr': self.get_attr_handler,
            'output': self.output_handler,
            'call_function': self.node_handler,
            'call_module': self.node_handler,
            'call_method': self.node_handler,
        }

    def _set_data_ptr(self, x):
        """
        Assign a UUID-based fake ``data_ptr`` to the tensor if it does not already have one.
        """
        if isinstance(x, torch.Tensor):
            if not x.data_ptr():
                data_ptr = uuid.uuid4()
                x.data_ptr = lambda: data_ptr

    def _is_inplace(self, node: Node) -> bool:
        """
        Check whether the node is an in-place operation.
        """
        if node.op == 'call_module':
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == 'call_function':
            return node.target in OUTPUT_SAVED_OPS
        return False

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module.
        """
        for node in self.module.graph.nodes:
            self.func_dict[node.op](node)
        return self.module

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node.
        """
        graph_info = GraphInfo()
        out = _normalize_tuple(getattr(node, '_meta_data', None))
        graph_info.fwd_out = list(out) if out[0] is not None else []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node.
        """
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node.
        """
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                output_tensors += par.meta["fwd_out"]
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle call_function, call_module and call_method nodes.
        """
        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info: ShardMetaInfo = node.best_strategy_info

        # set data_ptr for the tensors recorded in the ShardMetaInfo of this node
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out

        if self._is_inplace(node):
            # an inplace operation does not create a new tensor, and it only has one parent node
            # TODO: Verify this observation
            # share the parent's data_ptr with the input_tensor, buffer_tensor and output_tensor of the current node
            parent_node = list(node._input_nodes.keys())[0]
            parent_tensor: torch.Tensor = parent_node.meta.get("fwd_out")[0]
            for tensor in input_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in buffer_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in output_tensors:
                tensor.data_ptr = parent_tensor.data_ptr

        else:
            for par in node._input_nodes:
                # set data_ptr for the input_tensor of the current node from the output_tensor of its parent node
                for tensor in par.meta.get("fwd_out", []):
                    target_input_tensor = next(
                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
                    if target_input_tensor is not None:
                        target_input_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for any input_tensor that has not been set yet
            for tensor in input_tensors:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

            # set data_ptr for buffer_tensor
            for tensor in buffer_tensors:
                self._set_data_ptr(tensor)

            # set data_ptr for output_tensor
            for tensor in output_tensors:
                self._set_data_ptr(tensor)

        # attach the tensors to graph_info
        graph_info.fwd_in = input_tensors
        graph_info.fwd_tmp = buffer_tensors
        graph_info.fwd_out = output_tensors

        # fetch other memory information
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information
        # here we use fwd_time and bwd_time to deal with the case that
        # communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd

        node.meta = {**asdict(graph_info)}