mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] memory estimation for shape consistency (#2144)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel * [fx] add conv metainfo class * [fx] restore profiler * [fx] restore meta profiler * [autoparallel] modify unit test * [fx] modify unit test * [autoparallel] add batchnorm metainfo class * [autoparallel] fix batchnorm unit test function declaration * [fx] restore profiler * [fx] add relu metainfo class * [fx] restore profiler * [autoparallel] modify metainfo input * [autoparallel] add pooling metainfo * [autoparallel] add F.linear metainfo generator * [autoparallel] add binary elementwise metainfo * [fx] recover profiler * [autoparallel] fix forward memory calculation * [autoparallel] modify constants.py * [autoparallel] remove redundant print * [autoparallel] add F.conv metainfo * [autoparallel] linear fix * [autoparallel] memory estimation for communication actions * [autoparallel] fix docstring * [autoparallel] fix variables namepull/2164/head
parent
b87496a66b
commit
cfe2a9bd90
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Tuple, Union
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.tensor.shape_consistency import CommSpec
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .constants import (
|
||||
|
|
|
@ -3,8 +3,10 @@ from copy import deepcopy
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
|
||||
|
@ -403,6 +405,158 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
|||
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
|
||||
return valid_spec_dict
|
||||
|
||||
def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
|
||||
"""memory cost of the communication action sequence
|
||||
TODO: Currently we just consider tensor numel in the shape consistency manger,
|
||||
as the manager itself doesn't have the access to tensor dtype, we need to take
|
||||
it into consideration in memory estimation.
|
||||
|
||||
Args:
|
||||
comm_action_sequence (List[CommSpec]): list of communication actions
|
||||
|
||||
Returns:
|
||||
TrainCycleItem: memory (numel) cost of such comm_action_sequence
|
||||
"""
|
||||
|
||||
def compute_shape(sharding_spec: ShardingSpec):
|
||||
shape = sharding_spec.entire_shape
|
||||
for dim, shard in sharding_spec.dim_partition_dict.items():
|
||||
shape[dim] = shape[dim] // len(shard)
|
||||
return shape
|
||||
|
||||
def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||
"""analyze all_gather memory footprint
|
||||
all_gather will allocate memory for the output tensor, and there will be temp memory for
|
||||
all_gather operation, which is twice the size of output tensor
|
||||
|
||||
Args:
|
||||
comm_spec (CommSpec): input CommSpec
|
||||
discard_input (bool): whether to discard the input tensor
|
||||
alloc_numel (int): current allocated numel
|
||||
peak_numel (int): current peak numel
|
||||
"""
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = np.prod(input_shape)
|
||||
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
|
||||
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
|
||||
alloc_numel += output_numel
|
||||
if discard_input:
|
||||
alloc_numel -= input_numel
|
||||
|
||||
def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||
"""analyze split memory footprint
|
||||
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
|
||||
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
|
||||
generate new tensor in this case, so no memory will be allocated.
|
||||
|
||||
Args:
|
||||
comm_spec (CommSpec): input CommSpec
|
||||
discard_input (bool): whether to discard the input tensor
|
||||
alloc_numel (int): current allocated numel
|
||||
peak_numel (int): current peak numel
|
||||
"""
|
||||
shard_dim = comm_spec.shard_dim
|
||||
if shard_dim != 0:
|
||||
# if we don't shard the tensor on the first dimension, the split action will
|
||||
# generate a new tensor
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = np.prod(input_shape)
|
||||
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
|
||||
alloc_numel += output_numel
|
||||
peak_numel = max(peak_numel, alloc_numel)
|
||||
if discard_input:
|
||||
alloc_numel -= input_numel
|
||||
else:
|
||||
# if we shard the tensor on the first dimension, the split action will not generate
|
||||
# a new tensor, and as it will preserve a reference to the input tensor, we could
|
||||
# override the discard_input option here
|
||||
# NOTE: this special case might fail in some weird cases, e.g. if we have three split
|
||||
# actions in the comm actions sequence, the first split action operate on the second dimension,
|
||||
# the second split action operate on the first dimension, and the third split action operate, again,
|
||||
# on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
|
||||
# memory the same size as the output of first split action. However, the third split action will discard
|
||||
# the input tensor, and it actually should discard the tensor generated by the first split action, so in
|
||||
# the current memory estimation framework, we will overestimate the memory usage. But the above case is
|
||||
# kind of weird, and I think we could ignore it for now.
|
||||
pass
|
||||
|
||||
def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||
"""
|
||||
a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
|
||||
"""
|
||||
pass
|
||||
|
||||
def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||
"""analyze all_to_all memory footprint
|
||||
all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action
|
||||
is twice the size of output tensor if we shard input tensor on the first dimension, otherwise
|
||||
the temp memory is three times the size of output tensor
|
||||
|
||||
Args:
|
||||
comm_spec (CommSpec): input CommSpec
|
||||
discard_input (bool): whether to discard the input tensor
|
||||
alloc_numel (int): current allocated numel
|
||||
peak_numel (int): current peak numel
|
||||
"""
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = np.prod(input_shape)
|
||||
output_numel = input_numel
|
||||
shard_dim = comm_spec.shard_dim
|
||||
if shard_dim != 0:
|
||||
peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
|
||||
else:
|
||||
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
|
||||
alloc_numel += output_numel
|
||||
if discard_input:
|
||||
alloc_numel -= input_numel
|
||||
|
||||
def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||
"""
|
||||
a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
|
||||
"""
|
||||
pass
|
||||
|
||||
pattern_to_func_dict = {
|
||||
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
|
||||
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
|
||||
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
|
||||
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
|
||||
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
|
||||
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
|
||||
}
|
||||
|
||||
fwd_actions = []
|
||||
bwd_actions = []
|
||||
|
||||
# construct forward and backward comm actions sequence
|
||||
for comm_spec in comm_action_sequence:
|
||||
comm_spec: CommSpec
|
||||
fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
|
||||
fwd_actions.append(fwd_action)
|
||||
bwd_actions.append(bwd_action)
|
||||
|
||||
# analyze memory footprint of forward comm actions sequence
|
||||
fwd_alloc_numel = 0
|
||||
fwd_peak_numel = 0
|
||||
for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
|
||||
# the first forward comm action will not discard input
|
||||
if idx == 0:
|
||||
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
|
||||
else:
|
||||
fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
|
||||
|
||||
# analyze memory footprint for backward comm actions sequence
|
||||
bwd_alloc_numel = 0
|
||||
bwd_peak_numel = 0
|
||||
for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
|
||||
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
|
||||
|
||||
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
|
||||
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
|
||||
total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
|
||||
|
||||
return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
|
||||
|
||||
def shape_consistency(self, source_spec: ShardingSpec,
|
||||
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
||||
'''
|
||||
|
|
Loading…
Reference in New Issue