Browse Source

[autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline (#2261)

pull/2263/head^2
Boyuan Yao 2 years ago committed by GitHub
parent
commit
1ea99b869e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      colossalai/auto_parallel/meta_profiler/constants.py
  2. 2
      colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
  3. 4
      colossalai/auto_parallel/meta_profiler/metainfo.py
  4. 8
      colossalai/auto_parallel/passes/constants.py
  5. 75
      colossalai/auto_parallel/passes/meta_info_prop.py

5
colossalai/auto_parallel/meta_profiler/constants.py

@ -5,8 +5,11 @@ import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace operations
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
# list of inplace operations
INPLACE_OPS = [torch.flatten]
# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]

2
colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py

@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

4
colossalai/auto_parallel/meta_profiler/metainfo.py

@ -12,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
@ -104,6 +104,8 @@ class MetaInfo:
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
else:
kwargs = {'inplace': False}

8
colossalai/auto_parallel/passes/constants.py

@ -0,0 +1,8 @@
import torch
OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]

75
colossalai/auto_parallel/passes/meta_info_prop.py

@ -8,9 +8,9 @@ from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
def _normalize_tuple(x):
@ -46,7 +46,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
if node.op == 'call_method':
if node.op == 'call_module':
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@ -102,56 +102,51 @@ class MetaInfoProp:
meta_info: MetaInfo
# set data_ptr for input_tensor in MetaInfo class
input_tensor: List[torch.Tensor] = meta_info.fwd_in
buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
output_tensor: List[torch.Tensor] = meta_info.fwd_out
input_tensors: List[torch.Tensor] = meta_info.fwd_in
buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
output_tensors: List[torch.Tensor] = meta_info.fwd_out
if len(input_tensor) > 0:
if self._is_inplace(node):
# inplace operation will not create new tensor, and it only has one parent node
# TODO: Verify this observation
# set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
parent_node = list(node._input_nodes.keys())[0]
parent_tensor = parent_node.meta.get("fwd_out")[0]
parent_tensor: torch.Tensor
for tensor in input_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in buffer_tensors:
tensor.data_ptr = parent_tensor.data_ptr
for tensor in output_tensors:
tensor.data_ptr = parent_tensor.data_ptr
else:
for par in node._input_nodes:
if par.meta:
if len(par.meta["fwd_out"]) > 0:
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta["fwd_out"]:
tensor: torch.Tensor
target_tensor = next(
(x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
target_tensor.data_ptr = tensor.data_ptr
# set data_ptr for the input_tensor of current node from the output_tensor of its parent node
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
# set data_ptr for tensor in input_tensor that is not set
for tensor in input_tensor:
for tensor in input_tensors:
if not tensor.data_ptr():
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_in = input_tensor
if self._is_inplace(node):
# inplace operation will not create new tensor
# set data_ptr for buffer_tensor and output_tensor of current node
for tensor in input_tensor:
tensor: torch.Tensor
target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
None)
target_buffer_tensor.data_ptr = tensor.data_ptr
target_output_tensor.data_ptr = tensor.data_ptr
# attach them to graph_info
graph_info.fwd_tmp = buffer_tensor
graph_info.fwd_out = output_tensor
else:
# set data_ptr for buffer_tensor
for tensor in buffer_tensor:
for tensor in buffer_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_tmp = buffer_tensor
# set data_ptr for output_tensor
for tensor in output_tensor:
for tensor in output_tensors:
self._set_data_ptr(tensor)
# attach it to graph_info
graph_info.fwd_out = output_tensor
# attach them to graph_info
graph_info.fwd_in = input_tensors
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors
# fetch other memory informations
memory_cost = meta_info.memory_cost

Loading…
Cancel
Save