mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] memory estimation for shape consistency (#2144)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel * [fx] add conv metainfo class * [fx] restore profiler * [fx] restore meta profiler * [autoparallel] modify unit test * [fx] modify unit test * [autoparallel] add batchnorm metainfo class * [autoparallel] fix batchnorm unit test function declaration * [fx] restore profiler * [fx] add relu metainfo class * [fx] restore profiler * [autoparallel] modify metainfo input * [autoparallel] add pooling metainfo * [autoparallel] add F.linear metainfo generator * [autoparallel] add binary elementwise metainfo * [fx] recover profiler * [autoparallel] fix forward memory calculation * [autoparallel] modify constants.py * [autoparallel] remove redundant print * [autoparallel] add F.conv metainfo * [autoparallel] linear fix * [autoparallel] memory estimation for communication actions * [autoparallel] fix docstring * [autoparallel] fix variables namepull/2163/head
parent
b87496a66b
commit
cfe2a9bd90
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
from colossalai.tensor.shape_consistency import CommSpec
|
from colossalai.tensor.comm_spec import CommSpec
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
|
|
|
@ -3,8 +3,10 @@ from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||||
from colossalai.context.singleton_meta import SingletonMeta
|
from colossalai.context.singleton_meta import SingletonMeta
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||||
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
|
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
|
||||||
|
@ -403,6 +405,158 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
|
||||||
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
|
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
|
||||||
return valid_spec_dict
|
return valid_spec_dict
|
||||||
|
|
||||||
|
def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
|
||||||
|
"""memory cost of the communication action sequence
|
||||||
|
TODO: Currently we just consider tensor numel in the shape consistency manger,
|
||||||
|
as the manager itself doesn't have the access to tensor dtype, we need to take
|
||||||
|
it into consideration in memory estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comm_action_sequence (List[CommSpec]): list of communication actions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrainCycleItem: memory (numel) cost of such comm_action_sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_shape(sharding_spec: ShardingSpec):
|
||||||
|
shape = sharding_spec.entire_shape
|
||||||
|
for dim, shard in sharding_spec.dim_partition_dict.items():
|
||||||
|
shape[dim] = shape[dim] // len(shard)
|
||||||
|
return shape
|
||||||
|
|
||||||
|
def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||||
|
"""analyze all_gather memory footprint
|
||||||
|
all_gather will allocate memory for the output tensor, and there will be temp memory for
|
||||||
|
all_gather operation, which is twice the size of output tensor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comm_spec (CommSpec): input CommSpec
|
||||||
|
discard_input (bool): whether to discard the input tensor
|
||||||
|
alloc_numel (int): current allocated numel
|
||||||
|
peak_numel (int): current peak numel
|
||||||
|
"""
|
||||||
|
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||||
|
input_numel = np.prod(input_shape)
|
||||||
|
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
|
||||||
|
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
|
||||||
|
alloc_numel += output_numel
|
||||||
|
if discard_input:
|
||||||
|
alloc_numel -= input_numel
|
||||||
|
|
||||||
|
def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||||
|
"""analyze split memory footprint
|
||||||
|
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
|
||||||
|
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
|
||||||
|
generate new tensor in this case, so no memory will be allocated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comm_spec (CommSpec): input CommSpec
|
||||||
|
discard_input (bool): whether to discard the input tensor
|
||||||
|
alloc_numel (int): current allocated numel
|
||||||
|
peak_numel (int): current peak numel
|
||||||
|
"""
|
||||||
|
shard_dim = comm_spec.shard_dim
|
||||||
|
if shard_dim != 0:
|
||||||
|
# if we don't shard the tensor on the first dimension, the split action will
|
||||||
|
# generate a new tensor
|
||||||
|
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||||
|
input_numel = np.prod(input_shape)
|
||||||
|
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
|
||||||
|
alloc_numel += output_numel
|
||||||
|
peak_numel = max(peak_numel, alloc_numel)
|
||||||
|
if discard_input:
|
||||||
|
alloc_numel -= input_numel
|
||||||
|
else:
|
||||||
|
# if we shard the tensor on the first dimension, the split action will not generate
|
||||||
|
# a new tensor, and as it will preserve a reference to the input tensor, we could
|
||||||
|
# override the discard_input option here
|
||||||
|
# NOTE: this special case might fail in some weird cases, e.g. if we have three split
|
||||||
|
# actions in the comm actions sequence, the first split action operate on the second dimension,
|
||||||
|
# the second split action operate on the first dimension, and the third split action operate, again,
|
||||||
|
# on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
|
||||||
|
# memory the same size as the output of first split action. However, the third split action will discard
|
||||||
|
# the input tensor, and it actually should discard the tensor generated by the first split action, so in
|
||||||
|
# the current memory estimation framework, we will overestimate the memory usage. But the above case is
|
||||||
|
# kind of weird, and I think we could ignore it for now.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||||
|
"""
|
||||||
|
a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||||
|
"""analyze all_to_all memory footprint
|
||||||
|
all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action
|
||||||
|
is twice the size of output tensor if we shard input tensor on the first dimension, otherwise
|
||||||
|
the temp memory is three times the size of output tensor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comm_spec (CommSpec): input CommSpec
|
||||||
|
discard_input (bool): whether to discard the input tensor
|
||||||
|
alloc_numel (int): current allocated numel
|
||||||
|
peak_numel (int): current peak numel
|
||||||
|
"""
|
||||||
|
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||||
|
input_numel = np.prod(input_shape)
|
||||||
|
output_numel = input_numel
|
||||||
|
shard_dim = comm_spec.shard_dim
|
||||||
|
if shard_dim != 0:
|
||||||
|
peak_numel = max(peak_numel, alloc_numel + output_numel * 3)
|
||||||
|
else:
|
||||||
|
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
|
||||||
|
alloc_numel += output_numel
|
||||||
|
if discard_input:
|
||||||
|
alloc_numel -= input_numel
|
||||||
|
|
||||||
|
def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
|
||||||
|
"""
|
||||||
|
a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
pattern_to_func_dict = {
|
||||||
|
CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis],
|
||||||
|
CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis],
|
||||||
|
CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis],
|
||||||
|
CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis],
|
||||||
|
CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis],
|
||||||
|
CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
fwd_actions = []
|
||||||
|
bwd_actions = []
|
||||||
|
|
||||||
|
# construct forward and backward comm actions sequence
|
||||||
|
for comm_spec in comm_action_sequence:
|
||||||
|
comm_spec: CommSpec
|
||||||
|
fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
|
||||||
|
fwd_actions.append(fwd_action)
|
||||||
|
bwd_actions.append(bwd_action)
|
||||||
|
|
||||||
|
# analyze memory footprint of forward comm actions sequence
|
||||||
|
fwd_alloc_numel = 0
|
||||||
|
fwd_peak_numel = 0
|
||||||
|
for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
|
||||||
|
# the first forward comm action will not discard input
|
||||||
|
if idx == 0:
|
||||||
|
fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
|
||||||
|
else:
|
||||||
|
fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel)
|
||||||
|
|
||||||
|
# analyze memory footprint for backward comm actions sequence
|
||||||
|
bwd_alloc_numel = 0
|
||||||
|
bwd_peak_numel = 0
|
||||||
|
for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
|
||||||
|
bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
|
||||||
|
|
||||||
|
fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
|
||||||
|
bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel)
|
||||||
|
total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel)
|
||||||
|
|
||||||
|
return TrainCycleItem(fwd_mem, bwd_mem, total_mem)
|
||||||
|
|
||||||
def shape_consistency(self, source_spec: ShardingSpec,
|
def shape_consistency(self, source_spec: ShardingSpec,
|
||||||
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
||||||
'''
|
'''
|
||||||
|
|
Loading…
Reference in New Issue