mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
* [autoparallel] add binary elementwise metainfo
* [fx] recover profiler
* [autoparallel] fix forward memory calculation
* [autoparallel] modify constants.py
* [autoparallel] remove redundant print
* [autoparallel] add F.conv metainfo
* [autoparallel] linear fix
* [autoparallel] memory estimation for communication actions
* [autoparallel] fix docstring
* [autoparallel] fix variables name
* [autoparallel] attach tensor to metainfo class
* [autoparallel] fix dangerous try except
* [autoparallel] attach memory cost to shape consistency node
* [autoparallel] attach shape consistency node's metainfo to the node
* [autoparallel] remove todo in shape consistency memory estimation
* [autoparallel] fix the annotation

pull/2212/head^2
parent d0bc5a1b34
commit 24246f7aa5
@@ -64,7 +64,11 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
 
-    # store fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    # NOTE: It might seem a little bit weird here, we just want to align it with the older version
+    # of MetaInfoProp. In the future we might modify this part to make it clearer.
+    fwd_in = []
+    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, memory_cost, fwd_in
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
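Note on the pattern above: every meta-info generator now returns five values, with fwd_in / fwd_buffer / fwd_out carried as lists of meta tensors. A minimal sketch (toy names, not the ColossalAI API) of how a consumer could size those activations from metadata alone, assuming the same convention:

# Illustrative only: `toy_relu_meta_info` and `saved_activation_bytes` are hypothetical
# helpers mirroring the convention in the hunk above; they are not part of the commit.
from typing import List

import torch


def saved_activation_bytes(tensors: List[torch.Tensor]) -> int:
    # meta tensors carry shape/dtype only, so numel * element_size is a pure-metadata estimate
    return sum(t.numel() * t.element_size() for t in tensors)


def toy_relu_meta_info(input_tensor: torch.Tensor):
    output_tensor = torch.zeros_like(input_tensor, device='meta')
    # mirror the diff: ReLU keeps nothing in fwd_in, stashes an output-shaped
    # buffer, and exposes the output tensor
    fwd_in: List[torch.Tensor] = []
    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    return fwd_in, fwd_buffer, fwd_out


if __name__ == "__main__":
    x = torch.empty(32, 64, device='meta')
    fwd_in, fwd_buffer, fwd_out = toy_relu_meta_info(x)
    print(saved_activation_bytes(fwd_in + fwd_buffer + fwd_out))  # 2 * 32 * 64 * 4 = 16384 bytes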
@@ -6,7 +6,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost,
 from colossalai.fx.profiler.memory_utils import activation_size
 from colossalai.fx.profiler.opcount import flop_mapping
 
-from ..constants import BCAST_FUNC_OP
+from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
 from ..registry import meta_register
 
 __all__ = ['binary_elementwise_meta_info']
@@ -59,7 +59,9 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
-    # store fwd_in
-    fwd_in = fwd_in_args
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_buffer = []
+    fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
 
-    return compute_cost, memory_cost, fwd_in
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -129,7 +129,9 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
 
-    # store fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+    fwd_buffer = []
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, memory_cost, fwd_in
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -164,7 +164,9 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
 
-    # store fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+    fwd_buffer = []
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, memory_cost, fwd_in
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -95,7 +95,9 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
 
-    # store fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+    fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, memory_cost, fwd_in
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -59,10 +59,12 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
-    # store_fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = []
+    fwd_buffer = []
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, mem_cost, fwd_in
+    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
 
 
 @meta_register.register(torch.nn.MaxPool1d)
@@ -122,7 +124,9 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
 
-    # store_fwd_in
-    fwd_in = [input_tensor]
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+    fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    return compute_cost, mem_cost, fwd_in
+    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
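Aside on the max-pool hunk: the buffered index_matrix presumably corresponds to the argmax indices that max pooling keeps for its backward pass. On real tensors these can be inspected with plain PyTorch (outside the diff), and their shape matches the pooled output:

# Plain-PyTorch illustration, not part of the commit: return_indices exposes the
# index tensor that backward uses, which is what the meta-level fwd_buffer models.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
out, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
print(out.shape, indices.shape)  # torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 4, 4])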
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, List
 
 import numpy as np
 import torch
@@ -33,10 +33,13 @@ class MetaInfo:
         self.memory_cost: TrainCycleItem
 
         # list of input tensors
-        self.fwd_in: list[OperationData]
+        self.fwd_in: List[torch.Tensor]
 
-        # bool type to indicate whether the function will save forward activation
-        self.save_fwd_in: bool
+        # list of buffer tensors
+        self.fwd_buffer: List[torch.Tensor]
+
+        # list of output tensors
+        self.fwd_out: List[torch.Tensor]
 
         # sharding strategy
         self._strategy = strategy
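Taken together, MetaInfo now carries two cost items plus three tensor lists. A minimal stand-in is sketched below; only the attribute names come from the diff, while the dataclass form, defaults, and the total_bytes helper are illustrative assumptions:

# Stand-in mirroring the fields the diff gives MetaInfo; not ColossalAI's class.
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class MetaInfoSketch:
    compute_cost: object = None     # TrainCycleItem in ColossalAI
    memory_cost: object = None      # TrainCycleItem in ColossalAI
    fwd_in: List[torch.Tensor] = field(default_factory=list)
    fwd_buffer: List[torch.Tensor] = field(default_factory=list)
    fwd_out: List[torch.Tensor] = field(default_factory=list)

    def total_bytes(self) -> int:
        # size every attached meta tensor without allocating real memory
        return sum(t.numel() * t.element_size()
                   for t in self.fwd_in + self.fwd_buffer + self.fwd_out)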
@@ -94,19 +97,20 @@ class MetaInfo:
         """
         Compute meta info based on sharding strategy and the given target function.
         """
 
-        try:
+        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
+            f"Meta info for {self._target} is not registered."
+        if meta_register.has(self._target.__class__):
             # module
             meta_func = meta_register.get(self._target.__class__)
 
-            # check whether the target in the module list that we don't need to save activation
-            self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
-        except:
+            # check whether the target in the list that we don't need to save activation
+            save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
+        else:
             # function
             meta_func = meta_register.get(self._target)
 
-            # check whether the target in the module list that we don't need to save activation
-            self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION
+            # check whether the target in the list that we don't need to save activation
+            save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
 
         # construct args for meta_func
         args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
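This hunk is the "fix dangerous try except" item from the commit message: a bare except: swallows any error raised in the module branch and silently falls through to the function branch. A toy sketch of the explicit-check style (the registry below is a hypothetical stand-in, not ColossalAI's meta_register):

# Toy registry illustrating why an explicit has()/get() check with an assert is
# safer than a catch-all except: unregistered targets fail with a clear message.
class ToyRegistry:
    def __init__(self):
        self._store = {}

    def register(self, key, func):
        self._store[key] = func

    def has(self, key) -> bool:
        return key in self._store

    def get(self, key):
        return self._store[key]


registry = ToyRegistry()
registry.register("relu", lambda: "relu meta info")

target = "relu"
assert registry.has(target), f"Meta info for {target} is not registered."
meta_func = registry.get(target)
print(meta_func())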
@@ -118,4 +122,8 @@ class MetaInfo:
             kwargs = {'inplace': False}
 
         # compute metainfo with meta_func
-        self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args, **kwargs)
+        self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
+
+        # process corner case for NO_SAVE_ACTIVATION
+        if not save_fwd_in:
+            self.fwd_in = []
@@ -4,11 +4,13 @@ from typing import Dict, List
 import torch
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.meta_profiler import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
     OperationData,
     OperationDataType,
+    TrainCycleItem,
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import CommSpec
@@ -45,6 +47,52 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst
 
 
+def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
+    """
+    This method is used to construct `MetaInfo` for shape consistency node
+    TODO: Actually we could attain the cost information from resharding cost in node
+    handler, we should modify this part in the future.
+    """
+
+    def compute_shape(sharding_spec: ShardingSpec):
+        shape = sharding_spec.entire_shape
+        new_shape = []
+        for dim, shard in sharding_spec.dim_partition_dict.items():
+            new_shape.append(shape[dim] // len(shard))
+        return new_shape
+
+    meta_info = MetaInfo()
+    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec
+    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
+        origin_sharding_spec, target_sharding_spec)
+
+    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
+    # get mem cost for MetaInfo
+    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
+    element_length = node._meta_data.element_size()
+    mem_cost.fwd.activation *= element_length
+    mem_cost.fwd.temp *= element_length
+    mem_cost.bwd.activation *= element_length
+    mem_cost.bwd.temp *= element_length
+    mem_cost.total.activation *= element_length
+
+    meta_info.memory_cost = mem_cost
+
+    # get computation cost for MetaInfo
+    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
+    meta_info.compute_cost = compute_cost
+
+    # get tensor shape for MetaInfo
+    input_shape = compute_shape(origin_sharding_spec)
+    output_shape = compute_shape(target_sharding_spec)
+
+    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_buffer = []
+    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+
+    return meta_info
+
+
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
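The NOTE above is the key detail: shape_consistency_manager.mem_cost counts elements, so construct_meta_info scales every field by element_size() to get bytes. A worked example in plain PyTorch (no ColossalAI objects; the numbers are made up for illustration):

# Numel-to-bytes conversion as done in the hunk above, on stand-in values.
import torch

activation_numel = 1024                      # what the shape-consistency manager counts
meta_tensor = torch.empty(4, 256, device='meta', dtype=torch.float32)
element_length = meta_tensor.element_size()  # 4 bytes for float32
activation_bytes = activation_numel * element_length
print(element_length, activation_bytes)      # 4 4096 -> 4 KiB of activation memory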
@@ -126,6 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     runtime_apply,
                     args=(node, origin_dict_node, input_dict_node,
                           node_to_index_dict[node], user_node_index))
+                meta_info = construct_meta_info(node, user_node)
+                setattr(shape_consistency_node, 'best_metainfo', meta_info)
 
                 new_args = list(user_node.args)
                 new_kwargs = dict(user_node.kwargs)
@@ -407,9 +407,6 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
         """memory cost of the communication action sequence
-        TODO: Currently we just consider tensor numel in the shape consistency manger,
-        as the manager itself doesn't have the access to tensor dtype, we need to take
-        it into consideration in memory estimation.
 
         Args:
             comm_action_sequence (List[CommSpec]): list of communication actions
@@ -420,9 +417,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         def compute_shape(sharding_spec: ShardingSpec):
             shape = sharding_spec.entire_shape
+            new_shape = []
             for dim, shard in sharding_spec.dim_partition_dict.items():
-                shape[dim] = shape[dim] // len(shard)
-            return shape
+                new_shape.append(shape[dim] // len(shard))
+            return new_shape
 
         def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_gather memory footprint
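After the fix, compute_shape builds a fresh list instead of writing into entire_shape (which, if it is a torch.Size or other tuple-like, would reject item assignment), and the result holds one entry per sharded dimension. A self-contained rerun with stand-in objects (FakeShardingSpec is hypothetical, not ColossalAI's ShardingSpec):

# Worked example of the fixed helper: entire_shape (64, 32) sharded on dim 0
# across two mesh axes yields [32].
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeShardingSpec:
    entire_shape: Tuple[int, ...]
    dim_partition_dict: Dict[int, List[int]]


def compute_shape(sharding_spec: FakeShardingSpec) -> List[int]:
    shape = sharding_spec.entire_shape
    new_shape = []
    for dim, shard in sharding_spec.dim_partition_dict.items():
        new_shape.append(shape[dim] // len(shard))
    return new_shape


spec = FakeShardingSpec(entire_shape=(64, 32), dim_partition_dict={0: [0, 1]})
print(compute_shape(spec))  # [32]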
@@ -461,7 +459,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # generate a new tensor
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
+            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
             alloc_numel += output_numel
             peak_numel = max(peak_numel, alloc_numel)
             if discard_input:
@@ -538,8 +536,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint of forward comm actions sequence
         fwd_alloc_numel = 0
         fwd_peak_numel = 0
-        for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
+        for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
+            fwd_action, comm_spec = action_spec_pair
             if idx == 0:
                 fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
             else:
@@ -548,7 +547,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
-        for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+        for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+            bwd_action, comm_spec = action_spec_pair
             bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
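Both loop fixes above address the same problem: enumerate(zip(...)) yields (index, pair) two-tuples, so three loose targets raise ValueError at runtime. Either unpack the pair in the body, as the commit does, or nest the targets inline; a minimal sketch with made-up action/spec values:

# Demonstrates the corrected loop shape; the strings stand in for comm actions/specs.
actions = ['gather', 'all_reduce']
specs = ['spec0', 'spec1']

for idx, action_spec_pair in enumerate(zip(actions, specs)):
    action, spec = action_spec_pair                            # style used in the commit
    print(idx, action, spec)

for idx, (action, spec) in enumerate(zip(actions, specs)):     # equivalent inline unpacking
    print(idx, action, spec)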
@@ -37,7 +37,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
     # index of target node in computation graph
     node_index = 1
     # total number of target node strategies
-    strategy_number = 4
+    strategy_number = 9
     mem_test_for_node_strategy(rank=rank,
                                model=model,
                                device_mesh=device_mesh,
@@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
                                model=model,
                                device_mesh=device_mesh,
                                node_index=2,
-                               strategy_number=23,
+                               strategy_number=24,
                                input_args=[input],
                                meta_arg_names=["input"])