[autoparallel] memory estimation for shape consistency (#2144)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix

* [autoparallel] memory estimation for communication actions

* [autoparallel] fix docstring

* [autoparallel] fix variables name
pull/2164/head
Boyuan Yao 2022-12-21 10:39:37 +08:00 committed by GitHub
parent b87496a66b
commit cfe2a9bd90
2 changed files with 155 additions and 1 deletion


@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
-from colossalai.tensor.shape_consistency import CommSpec
+from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (


@@ -3,8 +3,10 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple
+import numpy as np
import torch
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
@@ -403,6 +405,158 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
        return valid_spec_dict

    def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
        """Memory cost of the communication action sequence.

        TODO: Currently we only consider tensor numel in the shape consistency manager;
        since the manager itself doesn't have access to the tensor dtype, dtype needs to
        be taken into account in the memory estimation later.

        Args:
            comm_action_sequence (List[CommSpec]): list of communication actions

        Returns:
            TrainCycleItem: memory (numel) cost of such a comm_action_sequence
        """

        def compute_shape(sharding_spec: ShardingSpec):
            # local shard shape: each partitioned dim is divided by the number of mesh axes it is sharded over
            shape = sharding_spec.entire_shape
            for dim, shard in sharding_spec.dim_partition_dict.items():
                shape[dim] = shape[dim] // len(shard)
            return shape
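        # e.g. with entire_shape (64, 32) and dim_partition_dict {0: [0, 1]} (dim 0 sharded over
        # two mesh axes), the local shard shape computed above is (64 // 2, 32) = (32, 32)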

        def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
            """Analyze the all_gather memory footprint.

            all_gather allocates memory for the output tensor, and there is temporary memory for
            the all_gather operation, which is twice the size of the output tensor.

            Args:
                comm_spec (CommSpec): input CommSpec
                discard_input (bool): whether to discard the input tensor
                alloc_numel (int): current allocated numel
                peak_numel (int): current peak numel
            """
            input_shape = compute_shape(comm_spec.sharding_spec)
            input_numel = np.prod(input_shape)
            output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
            peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
            alloc_numel += output_numel
            if discard_input:
                alloc_numel -= input_numel
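        # illustrative trace: gathering a 1024-numel shard across a mesh axis of size 4 yields a
        # 4096-numel output, so the peak rises to alloc_numel + 2 * 4096 and the allocation grows
        # by 4096 (minus the 1024-numel input if it is discarded)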

        def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
            """Analyze the split memory footprint.

            split allocates memory for the output tensor if the input tensor is not sharded on its
            first dimension. If the shard is applied to the first dimension, `torch.Tensor.contiguous()`
            does not create a new tensor, so no extra memory is allocated.

            Args:
                comm_spec (CommSpec): input CommSpec
                discard_input (bool): whether to discard the input tensor
                alloc_numel (int): current allocated numel
                peak_numel (int): current peak numel
            """
            shard_dim = comm_spec.shard_dim
            if shard_dim != 0:
                # if the tensor is not sharded on its first dimension, the split action
                # generates a new tensor
                input_shape = compute_shape(comm_spec.sharding_spec)
                input_numel = np.prod(input_shape)
                output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
                alloc_numel += output_numel
                peak_numel = max(peak_numel, alloc_numel)
                if discard_input:
                    alloc_numel -= input_numel
            else:
                # if the tensor is sharded on its first dimension, the split action does not generate
                # a new tensor; since it keeps a reference to the input tensor, the discard_input
                # option is effectively overridden here.
                # NOTE: this special case can misestimate in some corner cases, e.g. a sequence of
                # three split actions where the first operates on the second dimension, the second on
                # the first dimension, and the third again on the second dimension. After the first
                # two actions, memory the size of the first split's output has been allocated; the
                # third split then discards its input, although it should really discard the tensor
                # produced by the first split. The current framework therefore overestimates memory
                # usage in such cases, but they are rare enough to ignore for now.
                pass
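        # illustrative trace: splitting a 2048-numel tensor on dim 1 over a mesh axis of size 2
        # materializes a new 1024-numel shard (allocation grows by 1024, and the input may be
        # freed), whereas splitting on dim 0 is treated as free since the result keeps referencing
        # the input storage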

        def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
            """
            A dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory.
            """
            pass

        def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
            """Analyze the all_to_all memory footprint.

            all_to_all allocates memory for the output tensor; its temporary memory is twice the
            size of the output tensor if the input tensor is sharded on the first dimension,
            otherwise it is three times the size of the output tensor.

            Args:
                comm_spec (CommSpec): input CommSpec
                discard_input (bool): whether to discard the input tensor
                alloc_numel (int): current allocated numel
                peak_numel (int): current peak numel
            """
            input_shape = compute_shape(comm_spec.sharding_spec)
            input_numel = np.prod(input_shape)
            output_numel = input_numel
            shard_dim = comm_spec.shard_dim
            if shard_dim != 0:
                peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
            else:
                peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
            alloc_numel += output_numel
            if discard_input:
                alloc_numel -= input_numel
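        # illustrative trace: for a 1024-numel shard the all_to_all output is also 1024 numel; the
        # peak is alloc_numel + 3 * 1024 when the input is not sharded on dim 0, and
        # alloc_numel + 2 * 1024 when it is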

        def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
            """
            A dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory.
            """
            pass

        pattern_to_func_dict = {
            CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
            CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
            CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
            CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
            CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
            CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
        }
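        # each collective pattern maps to a [forward, backward] pair of analysis functions,
        # e.g. GATHER_FWD_SPLIT_BWD runs gather_analysis in the forward pass and split_analysis
        # in the backward pass; MIXGATHER_FWD_SPLIT_BWD is not handled yet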

        fwd_actions = []
        bwd_actions = []

        # construct the forward and backward comm action sequences
        for comm_spec in comm_action_sequence:
            comm_spec: CommSpec
            fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
            fwd_actions.append(fwd_action)
            bwd_actions.append(bwd_action)

        # analyze the memory footprint of the forward comm action sequence
        fwd_alloc_numel = 0
        fwd_peak_numel = 0
        for idx, (fwd_action, comm_spec) in enumerate(zip(fwd_actions, comm_action_sequence)):
            # the first forward comm action does not discard its input
            if idx == 0:
                fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
            else:
                fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)

        # analyze the memory footprint of the backward comm action sequence
        bwd_alloc_numel = 0
        bwd_peak_numel = 0
        for idx, (bwd_action, comm_spec) in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
            bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

        fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
        bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
        total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)

        return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
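
    # usage sketch (illustrative; TrainCycleItem attribute names are assumed):
    #     mem = ShapeConsistencyManager().mem_cost(comm_action_sequence)
    #     fwd_activation, fwd_temp = mem.fwd.activation, mem.fwd.temp
    # the returned costs are expressed in numel rather than bytes (see the TODO above)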

    def shape_consistency(self, source_spec: ShardingSpec,
                          target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
        '''