import uuid
from dataclasses import asdict
from typing import List

import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo


def _normalize_tuple(x):
    """
    Wrap a non-tuple value in a 1-tuple; return tuples unchanged.
    """
    if not isinstance(x, tuple):
        return (x,)
    return x


@compatibility(is_backward_compatible=False)
class MetaInfoProp:
    """
    Annotate every node of a traced ``GraphModule`` with meta information
    (forward/backward tensors, memory cost and compute time), dispatching on
    ``node.op``.
    """

    def __init__(self, module: GraphModule) -> None:
        self.module = module
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "get_attr": self.get_attr_handler,
            "output": self.output_handler,
            "call_function": self.node_handler,
            "call_module": self.node_handler,
            "call_method": self.node_handler,
        }
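
    # func_dict covers all six torch.fx opcodes (placeholder, get_attr,
    # output, call_function, call_module, call_method), so run() can dispatch
    # on node.op without a fallback branch.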

    def _set_data_ptr(self, x):
        """
        Assign a fresh uuid as the tensor's fake data pointer if it does not
        already have one.
        """
        if isinstance(x, torch.Tensor):
            if not x.data_ptr():
                data_ptr = uuid.uuid4()
                x.data_ptr = lambda: data_ptr
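
    # Note: the guard above relies on the profiler's storage-less tensors
    # reporting data_ptr() == 0. The uuid then acts as a stand-in identity:
    # two meta tensors alias each other exactly when their (fake) data_ptr()
    # values compare equal, which node_handler below exploits when wiring
    # inputs to parent outputs.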

    def _is_inplace(self, node: Node):
        """
        Check whether the node is an in-place operation.
        """
        if node.op == "call_module":
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
        return False
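
    # Assumption: OUTPUT_SAVED_OPS and OUTPUT_SAVED_MOD (imported from
    # colossalai.auto_parallel.passes.constants) list the functions and module
    # classes whose outputs are treated as reusing their input's storage, so
    # no fresh allocation is charged to them.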

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module and return it.
        """
        for node in self.module.graph.nodes:
            node: Node
            self.func_dict[node.op](node)
        return self.module
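
    # Usage sketch (illustrative; assumes the auto-parallel solver has already
    # attached `best_strategy_info` to every call_* node):
    #
    #   gm: GraphModule = ...     # symbolically traced module
    #   MetaInfoProp(gm).run()    # annotates each node.meta in place
    #   for n in gm.graph.nodes:
    #       n.meta["fwd_mem_out"] # per-node activation memory, and so on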

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node: record its meta tensor(s) as forward
        outputs.
        """
        graph_info = GraphInfo()
        out = _normalize_tuple(getattr(node, "_meta_data", None))
        graph_info.fwd_out = list(out) if out[0] is not None else []
        node.meta = {**asdict(graph_info)}
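
    # Presumably `_meta_data` is attached to placeholder nodes during tracing
    # or shape propagation; a placeholder without it simply gets an empty
    # fwd_out list.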

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node: attributes carry no propagated meta
        information, so an empty ``GraphInfo`` is attached.
        """
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node: collect the forward outputs of its parents as
        this node's forward inputs.
        """
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                output_tensors += par.meta["fwd_out"]
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle call_function, call_module and call_method nodes.
        """
        assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info = node.best_strategy_info
        meta_info: ShardMetaInfo

        # set data_ptr for the tensors recorded in the ShardMetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out

        if self._is_inplace(node):
            # An in-place operation does not create a new tensor and has only
            # one parent node (TODO: verify this observation), so the input,
            # buffer and output tensors of the current node all share the
            # parent's data_ptr.
            parent_node = list(node._input_nodes.keys())[0]
            parent_tensor = parent_node.meta.get("fwd_out")[0]
            parent_tensor: torch.Tensor
            for tensor in input_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in buffer_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in output_tensors:
                tensor.data_ptr = parent_tensor.data_ptr

        else:
            for par in node._input_nodes:
                # set data_ptr for the input tensors of the current node from
                # the output tensors of its parent nodes, matching by shape
                for tensor in par.meta.get("fwd_out", []):
                    tensor: torch.Tensor
                    target_input_tensor = next(
                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
                    )
                    if target_input_tensor is not None:
                        target_input_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for any input tensor that is still unset
            for tensor in input_tensors:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

            # set data_ptr for buffer tensors
            for tensor in buffer_tensors:
                self._set_data_ptr(tensor)

            # set data_ptr for output tensors
            for tensor in output_tensors:
                self._set_data_ptr(tensor)

        # attach them to graph_info
        graph_info.fwd_in = input_tensors
        graph_info.fwd_tmp = buffer_tensors
        graph_info.fwd_out = output_tensors

        # fetch other memory information
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information; fwd_time and bwd_time are used here to
        # handle the case where the communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd

        node.meta = {**asdict(graph_info)}
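
    # After the pass, node.meta holds every dataclass field of GraphInfo via
    # asdict(), including those populated above: fwd_in, fwd_tmp, fwd_out,
    # fwd_mem_tmp, fwd_mem_out, bwd_mem_tmp, bwd_mem_out, fwd_time, bwd_time.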