diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py
index 714674b7b..35b8c13ee 100644
--- a/colossalai/auto_parallel/meta_profiler/constants.py
+++ b/colossalai/auto_parallel/meta_profiler/constants.py
@@ -5,8 +5,11 @@ import torch.nn as nn
 
 from ..tensor_shard.constants import *
 
-# list of inplace operations
+# list of inplace modules
 INPLACE_MODULE = [nn.ReLU]
 
+# list of inplace operations
+INPLACE_OPS = [torch.flatten]
+
 # list of operations that do not save forward activations
 NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index b4cc58d05..15c3063b7 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
+    fwd_in = []
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
 
diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py
index ff76e3059..218187768 100644
--- a/colossalai/auto_parallel/meta_profiler/metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -12,7 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
+from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register
 
 __all__ = ['MetaInfo']
@@ -104,6 +104,8 @@ class MetaInfo:
         # construct kwargs
         if self.target in INPLACE_MODULE:
             kwargs = {'inplace': self.target.inplace}
+        elif self.target in INPLACE_OPS:
+            kwargs = {'inplace': True}
         else:
             kwargs = {'inplace': False}
 
diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py
new file mode 100644
index 000000000..b86088474
--- /dev/null
+++ b/colossalai/auto_parallel/passes/constants.py
@@ -0,0 +1,8 @@
+import torch
+
+OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index 607f7e17e..bdeaeffed 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -8,9 +8,9 @@ from torch.fx import GraphModule
 from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
-from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 
 
 def _normalize_tuple(x):
@@ -46,7 +46,7 @@ class MetaInfoProp:
         """
         Check if the node is inplace operation.
         """
-        if node.op == 'call_method':
+        if node.op == 'call_module':
             return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
         elif node.op == "call_function":
             return node.target in OUTPUT_SAVED_OPS
@@ -102,56 +102,51 @@ class MetaInfoProp:
         meta_info: MetaInfo
 
         # set data_ptr for input_tensor in MetaInfo class
-        input_tensor: List[torch.Tensor] = meta_info.fwd_in
-        buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
-        output_tensor: List[torch.Tensor] = meta_info.fwd_out
+        input_tensors: List[torch.Tensor] = meta_info.fwd_in
+        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
+        output_tensors: List[torch.Tensor] = meta_info.fwd_out
 
-        if len(input_tensor) > 0:
+        if self._is_inplace(node):
+            # an inplace operation will not create a new tensor, and it only has one parent node
+            # TODO: Verify this observation
+            # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
+            parent_node = list(node._input_nodes.keys())[0]
+            parent_tensor = parent_node.meta.get("fwd_out")[0]
+            parent_tensor: torch.Tensor
+            for tensor in input_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in buffer_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in output_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+
+        else:
             for par in node._input_nodes:
-                if par.meta:
-                    if len(par.meta["fwd_out"]) > 0:
-                        # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
-                        for tensor in par.meta["fwd_out"]:
-                            tensor: torch.Tensor
-                            target_tensor = next(
-                                (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
-                            target_tensor.data_ptr = tensor.data_ptr
+                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
+                for tensor in par.meta.get("fwd_out", []):
+                    tensor: torch.Tensor
+                    target_input_tensor = next(
+                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+                    if target_input_tensor is not None:
+                        target_input_tensor.data_ptr = tensor.data_ptr
 
             # set data_ptr for tensor in input_tensor that is not set
-            for tensor in input_tensor:
+            for tensor in input_tensors:
                 if not tensor.data_ptr():
                     self._set_data_ptr(tensor)
 
-        # attach it to graph_info
-        graph_info.fwd_in = input_tensor
-
-        if self._is_inplace(node):
-            # inplace operation will not create new tensor
-            # set data_ptr for buffer_tensor and output_tensor of current node
-            for tensor in input_tensor:
-                tensor: torch.Tensor
-                target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_buffer_tensor.data_ptr = tensor.data_ptr
-                target_output_tensor.data_ptr = tensor.data_ptr
-            # attach them to graph_info
-            graph_info.fwd_tmp = buffer_tensor
-            graph_info.fwd_out = output_tensor
-
-        else:
             # set data_ptr for buffer_tensor
-            for tensor in buffer_tensor:
+            for tensor in buffer_tensors:
                 self._set_data_ptr(tensor)
-                # attach it to graph_info
-                graph_info.fwd_tmp = buffer_tensor
 
             # set data_ptr for output_tensor
-            for tensor in output_tensor:
+            for tensor in output_tensors:
                 self._set_data_ptr(tensor)
-                # attach it to graph_info
-                graph_info.fwd_out = output_tensor
+
+        # attach them to graph_info
+        graph_info.fwd_in = input_tensors
+        graph_info.fwd_tmp = buffer_tensors
+        graph_info.fwd_out = output_tensors
 
         # fetch other memory informations
         memory_cost = meta_info.memory_cost