[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix

* [autoparallel] memory estimation for communication actions

* [autoparallel] fix docstring

* [autoparallel] fix variables name

* [autoparallel] attach tensor to metainfo class

* [autoparallel] fix dangerous try except

* [autoparallel] attach memory cost to shape consistency node

* [autoparallel] attach shape consistency node's metainfo to the node

* [autoparallel] remove todo in shape consistency memory estimation

* [autoparallel] fix the annotation
pull/2212/head^2
Boyuan Yao 2022-12-28 13:37:40 +08:00 committed by GitHub
parent d0bc5a1b34
commit 24246f7aa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 118 additions and 44 deletions

View File

@ -64,7 +64,11 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
# NOTE: It might seem a little bit weird here, we just want to align it with the older version
# of MetaInfoProp. In the future we might modify this part to make it clearer.
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -6,7 +6,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost,
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..constants import BCAST_FUNC_OP
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
@ -59,7 +59,9 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in
fwd_in = fwd_in_args
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
return compute_cost, memory_cost, fwd_in
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -129,7 +129,9 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -164,7 +164,9 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -95,7 +95,9 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -59,10 +59,12 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store_fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@meta_register.register(torch.nn.MaxPool1d)
@ -122,7 +124,9 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store_fwd_in
fwd_in = [input_tensor]
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, mem_cost, fwd_in
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, List
import numpy as np
import torch
@ -33,10 +33,13 @@ class MetaInfo:
self.memory_cost: TrainCycleItem
# list of input tensors
self.fwd_in: list[OperationData]
self.fwd_in: List[torch.Tensor]
# bool type to indicate whether the function will save forward activation
self.save_fwd_in: bool
# list of buffer tensors
self.fwd_buffer: List[torch.Tensor]
# list of output tensors
self.fwd_out: List[torch.Tensor]
# sharding strategy
self._strategy = strategy
@ -94,19 +97,20 @@ class MetaInfo:
"""
Compute meta info based on sharding strategy and the given target function.
"""
try:
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
# check whether the target in the module list that we don't need to save activation
self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
except:
# check whether the target is in the list for which we don't need to save activation
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
else:
# function
meta_func = meta_register.get(self._target)
# check whether the target in the module list that we don't need to save activation
self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION
# check whether the target is in the list for which we don't need to save activation
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
@ -118,4 +122,8 @@ class MetaInfo:
kwargs = {'inplace': False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args, **kwargs)
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
# process corner case for NO_SAVE_ACTIVATION
if not save_fwd_in:
self.fwd_in = []

View File

@ -4,11 +4,13 @@ from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
@ -45,6 +47,52 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
return rst
def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
    """
    Construct the `MetaInfo` for a shape consistency node.

    The returned object describes the conversion from ``node.sharding_spec`` to
    ``user_node.sharding_spec``: its compute cost, its memory cost and meta
    tensors standing in for the input and output of the conversion.

    TODO: Actually we could attain the cost information from resharding cost in node
        handler, we should modify this part in the future.

    Args:
        node (Node): producer node; its ``sharding_spec`` is the origin spec and its
            ``_meta_data`` provides the element size used to scale the memory cost.
        user_node (Node): consumer node; its ``sharding_spec`` is the target spec.

    Returns:
        MetaInfo: meta info with ``compute_cost``, ``memory_cost``, ``fwd_in``,
            ``fwd_buffer`` and ``fwd_out`` filled in.
    """

    def compute_shape(sharding_spec: ShardingSpec):
        # Start from a copy of the full shape: dims that do not appear in
        # dim_partition_dict keep their size instead of being dropped, and
        # sharding_spec.entire_shape itself is never mutated.
        new_shape = list(sharding_spec.entire_shape)
        for dim, shard in sharding_spec.dim_partition_dict.items():
            # NOTE(review): presumably `shard` lists the mesh axes this dim is split
            # over, in which case the divisor should be the product of those mesh
            # sizes rather than len(shard) -- TODO confirm against ShardingSpec.
            new_shape[dim] = new_shape[dim] // len(shard)
        return new_shape

    meta_info = MetaInfo()
    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

    # NOTE: shape_consistency_manager.mem_cost counts elements (numel), so every
    # activation / temp entry is scaled by the element size of node._meta_data.
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    element_length = node._meta_data.element_size()
    for phase_cost in (mem_cost.fwd, mem_cost.bwd, mem_cost.total):
        phase_cost.activation *= element_length
        # total.temp is scaled as well so that all three items share the same unit
        phase_cost.temp *= element_length
    meta_info.memory_cost = mem_cost

    # get computation cost for MetaInfo
    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
    meta_info.compute_cost = compute_cost

    # get tensor shape for MetaInfo; the tensors live on the meta device, only their shapes matter
    input_shape = compute_shape(origin_sharding_spec)
    output_shape = compute_shape(target_sharding_spec)
    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
    meta_info.fwd_buffer = []
    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]

    return meta_info
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@ -126,6 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
meta_info = construct_meta_info(node, user_node)
setattr(shape_consistency_node, 'best_metainfo', meta_info)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)

View File

@ -407,9 +407,6 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
"""memory cost of the communication action sequence
TODO: Currently we just consider tensor numel in the shape consistency manager,
as the manager itself doesn't have the access to tensor dtype, we need to take
it into consideration in memory estimation.
Args:
comm_action_sequence (List[CommSpec]): list of communication actions
@ -420,9 +417,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
def compute_shape(sharding_spec: ShardingSpec):
shape = sharding_spec.entire_shape
new_shape = []
for dim, shard in sharding_spec.dim_partition_dict.items():
shape[dim] = shape[dim] // len(shard)
return shape
new_shape.append(shape[dim] // len(shard))
return new_shape
def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
"""analyze all_gather memory footprint
@ -461,7 +459,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
alloc_numel += output_numel
peak_numel = max(peak_numel, alloc_numel)
if discard_input:
@ -538,8 +536,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# analyze memory footprint of forward comm actions sequence
fwd_alloc_numel = 0
fwd_peak_numel = 0
for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
if idx == 0:
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
else:
@ -548,7 +547,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# analyze memory footprint for backward comm actions sequence
bwd_alloc_numel = 0
bwd_peak_numel = 0
for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
bwd_action, comm_spec = action_spec_pair
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)

View File

@ -37,7 +37,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
# index of target node in computation graph
node_index = 1
# total number of target node strategies
strategy_number = 4
strategy_number = 9
mem_test_for_node_strategy(rank=rank,
model=model,
device_mesh=device_mesh,

View File

@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
model=model,
device_mesh=device_mesh,
node_index=2,
strategy_number=23,
strategy_number=24,
input_args=[input],
meta_arg_names=["input"])