|
|
|
@ -8,9 +8,9 @@ from torch.fx import GraphModule
|
|
|
|
|
from torch.fx.node import Node |
|
|
|
|
|
|
|
|
|
from colossalai.auto_parallel.meta_profiler import MetaInfo |
|
|
|
|
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS |
|
|
|
|
from colossalai.fx._compatibility import compatibility |
|
|
|
|
from colossalai.fx.profiler import GraphInfo |
|
|
|
|
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_tuple(x): |
|
|
|
@ -46,7 +46,7 @@ class MetaInfoProp:
|
|
|
|
|
""" |
|
|
|
|
Check if the node is inplace operation. |
|
|
|
|
""" |
|
|
|
|
if node.op == 'call_method': |
|
|
|
|
if node.op == 'call_module': |
|
|
|
|
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD |
|
|
|
|
elif node.op == "call_function": |
|
|
|
|
return node.target in OUTPUT_SAVED_OPS |
|
|
|
@ -102,56 +102,51 @@ class MetaInfoProp:
|
|
|
|
|
meta_info: MetaInfo |
|
|
|
|
|
|
|
|
|
# set data_ptr for input_tensor in MetaInfo class |
|
|
|
|
input_tensor: List[torch.Tensor] = meta_info.fwd_in |
|
|
|
|
buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer |
|
|
|
|
output_tensor: List[torch.Tensor] = meta_info.fwd_out |
|
|
|
|
input_tensors: List[torch.Tensor] = meta_info.fwd_in |
|
|
|
|
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer |
|
|
|
|
output_tensors: List[torch.Tensor] = meta_info.fwd_out |
|
|
|
|
|
|
|
|
|
if len(input_tensor) > 0: |
|
|
|
|
if self._is_inplace(node): |
|
|
|
|
# inplace operation will not create new tensor, and it only has one parent node |
|
|
|
|
# TODO: Verify this observation |
|
|
|
|
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node |
|
|
|
|
parent_node = list(node._input_nodes.keys())[0] |
|
|
|
|
parent_tensor = parent_node.meta.get("fwd_out")[0] |
|
|
|
|
parent_tensor: torch.Tensor |
|
|
|
|
for tensor in input_tensors: |
|
|
|
|
tensor.data_ptr = parent_tensor.data_ptr |
|
|
|
|
for tensor in buffer_tensors: |
|
|
|
|
tensor.data_ptr = parent_tensor.data_ptr |
|
|
|
|
for tensor in output_tensors: |
|
|
|
|
tensor.data_ptr = parent_tensor.data_ptr |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
for par in node._input_nodes: |
|
|
|
|
if par.meta: |
|
|
|
|
if len(par.meta["fwd_out"]) > 0: |
|
|
|
|
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node |
|
|
|
|
for tensor in par.meta["fwd_out"]: |
|
|
|
|
tensor: torch.Tensor |
|
|
|
|
target_tensor = next( |
|
|
|
|
(x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None) |
|
|
|
|
target_tensor.data_ptr = tensor.data_ptr |
|
|
|
|
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node |
|
|
|
|
for tensor in par.meta.get("fwd_out", []): |
|
|
|
|
tensor: torch.Tensor |
|
|
|
|
target_input_tensor = next( |
|
|
|
|
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) |
|
|
|
|
if target_input_tensor is not None: |
|
|
|
|
target_input_tensor.data_ptr = tensor.data_ptr |
|
|
|
|
|
|
|
|
|
# set data_ptr for tensor in input_tensor that is not set |
|
|
|
|
for tensor in input_tensor: |
|
|
|
|
for tensor in input_tensors: |
|
|
|
|
if not tensor.data_ptr(): |
|
|
|
|
self._set_data_ptr(tensor) |
|
|
|
|
|
|
|
|
|
# attach it to graph_info |
|
|
|
|
graph_info.fwd_in = input_tensor |
|
|
|
|
|
|
|
|
|
if self._is_inplace(node): |
|
|
|
|
# inplace operation will not create new tensor |
|
|
|
|
# set data_ptr for buffer_tensor and output_tensor of current node |
|
|
|
|
for tensor in input_tensor: |
|
|
|
|
tensor: torch.Tensor |
|
|
|
|
target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape), |
|
|
|
|
None) |
|
|
|
|
target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape), |
|
|
|
|
None) |
|
|
|
|
target_buffer_tensor.data_ptr = tensor.data_ptr |
|
|
|
|
target_output_tensor.data_ptr = tensor.data_ptr |
|
|
|
|
# attach them to graph_info |
|
|
|
|
graph_info.fwd_tmp = buffer_tensor |
|
|
|
|
graph_info.fwd_out = output_tensor |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
# set data_ptr for buffer_tensor |
|
|
|
|
for tensor in buffer_tensor: |
|
|
|
|
for tensor in buffer_tensors: |
|
|
|
|
self._set_data_ptr(tensor) |
|
|
|
|
# attach it to graph_info |
|
|
|
|
graph_info.fwd_tmp = buffer_tensor |
|
|
|
|
|
|
|
|
|
# set data_ptr for output_tensor |
|
|
|
|
for tensor in output_tensor: |
|
|
|
|
for tensor in output_tensors: |
|
|
|
|
self._set_data_ptr(tensor) |
|
|
|
|
# attach it to graph_info |
|
|
|
|
graph_info.fwd_out = output_tensor |
|
|
|
|
|
|
|
|
|
# attach them to graph_info |
|
|
|
|
graph_info.fwd_in = input_tensors |
|
|
|
|
graph_info.fwd_tmp = buffer_tensors |
|
|
|
|
graph_info.fwd_out = output_tensors |
|
|
|
|
|
|
|
|
|
# fetch other memory informations |
|
|
|
|
memory_cost = meta_info.memory_cost |
|
|
|
|