mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key
* fix import error
* fix matmul test issue
* fix unit test issue

parent 4268ae017b
commit 49216d7ab1
@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
        if -1 in dim_partition_dict:
            shard = dim_partition_dict.pop(-1)
            dim_partition_dict[0] = shard
        if 1 in dim_partition_dict:
            shard = dim_partition_dict.pop(1)
            dim_partition_dict[0] = shard

        # re-init the sharding spec
        input_sharding_spec.__init__(input_sharding_spec.device_mesh,
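The hunk above is the core MatMulHandler fix: a 1D operand's partition dict may arrive keyed by -1 or by 1, and both keys must be remapped onto dim 0 before the sharding spec is re-initialized. A minimal stand-alone illustration of that remapping, using a plain dict rather than the handler's actual state:

# Illustration only: move a shard keyed by -1 or 1 onto dim 0 of a 1D operand,
# mirroring the two `if` blocks in the hunk above.
def remap_1d_partition(dim_partition_dict):
    for key in (-1, 1):
        if key in dim_partition_dict:
            dim_partition_dict[0] = dim_partition_dict.pop(key)
    return dim_partition_dict

print(remap_1d_partition({-1: [0]}))  # {0: [0]}
print(remap_1d_partition({1: [1]}))   # {0: [1]}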
@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator

@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    @ignore_sharding_exception
    def split_input_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
        dim_partition_dict_mapping = {

@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict_mapping = {

@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def non_split(self):
        name = f'RR = RR x R'
        dim_partition_dict_mapping = {

@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {

@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {

@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    @ignore_sharding_exception
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
        dim_partition_dict_mapping = {
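Each strategy-building method in the BatchNormStrategyGenerator hunks above gains an @ignore_sharding_exception decorator. As a rough sketch of the pattern (not the library's actual implementation, which lives in colossalai.auto_parallel.tensor_shard.utils), such a decorator turns an illegal candidate split into a skipped strategy instead of an aborted search:

import functools

class ShardingException(Exception):
    """Stand-in for the exception raised when a sharding spec is illegal."""

def ignore_sharding_exception(func):
    # Hedged sketch: swallow the sharding error and return None so that
    # strategy enumeration can simply skip the invalid candidate.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingException:
            return None
    return wrapper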
@@ -3,9 +3,12 @@ import operator
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
                                                          enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
    enumerate_all_possible_1d_sharding,
    enumerate_all_possible_2d_sharding,
    ignore_sharding_exception,
)

from .strategy_generator import StrategyGenerator

@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    @ignore_sharding_exception
    def _generate_strategy_with_dim_partition(self, dim_partition):
        dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict


class StrategyGenerator(ABC):

@@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
            op_data = self.op_data[op_data_name]
            if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
                sharding_spec = []
                for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
                for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                    dim_size = len(logical_shape)
                    dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
                    sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                                 entire_shape=output.shape,
                                                 entire_shape=logical_shape,
                                                 dim_partition_dict=dim_partition_dict_element)
            else:
                dim_size = len(op_data.logical_shape)
                dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
                sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                             entire_shape=op_data.logical_shape,
                                             dim_partition_dict=dim_partition_dict)
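The second hunk above switches the tuple branch from the runtime tensors (op_data.data) to the recorded logical shapes, normalizing negative dim keys on the way. A hedged, self-contained sketch of that reading, assuming each per-element spec is meant to be collected into the sharding_spec list initialized just before the loop (the helper name below is illustrative, not the class method):

from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict

def build_tuple_sharding_specs(device_mesh, logical_shapes, dim_partition_dicts):
    # One ShardingSpec per element of a tuple output, built from its logical shape.
    specs = []
    for logical_shape, partition in zip(logical_shapes, dim_partition_dicts):
        dim_size = len(logical_shape)
        # e.g. a key of -1 becomes dim_size - 1 before the spec is constructed
        partition = convert_dim_partition_dict(dim_size, partition)
        specs.append(ShardingSpec(device_mesh=device_mesh,
                                  entire_shape=logical_shape,
                                  dim_partition_dict=partition))
    return specs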
@@ -1,19 +1,17 @@
from . import distspec
from .colo_parameter import ColoParameter
from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec

from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .comm_spec import CollectiveCommPattern, CommSpec
from . import distspec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

__all__ = [
    'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
]
@@ -1,11 +1,11 @@
import torch

from typing import Optional

import torch

from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec


def filter_args(func, *args):
@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set

import torch

from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from colossalai.tensor.tensor_spec import ColoTensorSpec

from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@@ -1,12 +1,14 @@
from colossalai.tensor.distspec import _DistSpec
# from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager

import torch
import torch.distributed as dist
# from colossalai.nn.layer.utils import divide
from numpy import prod
from packaging import version

from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.process_group import ProcessGroup


# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -1,9 +1,11 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from contextlib import contextmanager
from typing import Any, List, Tuple

import torch

from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.tensor_spec import ColoTensorSpec


class ParamOpHook(ABC):
@@ -6,6 +6,8 @@ import torch

from colossalai.device.device_mesh import DeviceMesh

from .utils import merge_same_dim_mesh_list

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

ALLGATHER_COST = 20

@@ -181,8 +183,12 @@ class ShardingSpec:
        self.dim_partition_dict = dim_partition_dict
        self.sharding_sequence = sharding_sequence
        if self.sharding_sequence is None:
            assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
                                                               dim_partition_dict=self.dim_partition_dict)
            self.convert_dict_to_shard_sequence()
        elif self.dim_partition_dict is None:
            assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
            self.convert_shard_sequence_to_dict()
        self._sanity_check()
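With merge_same_dim_mesh_list called in ShardingSpec.__init__, partition dicts whose keys alias the same dimension are collapsed before the sharding sequence is derived. A hedged usage sketch; the DeviceMesh construction follows the pattern used elsewhere in the auto-parallel tests and its exact arguments are an assumption:

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))   # assumed 2x2 logical mesh

# Keys 1 and -1 address the same dim of a 2D tensor; after this commit the
# constructor merges them into a single entry before building the sequence.
spec = ShardingSpec(device_mesh,
                    entire_shape=torch.Size([8, 8]),
                    dim_partition_dict={1: [0], -1: [1]})
print(spec.dim_partition_dict)    # expected: {1: [0, 1]}
print(spec.sharding_sequence)     # expected: [R, S01]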
@@ -1,14 +1,16 @@
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
from typing import Optional

from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup

from .compute_spec import ComputeSpec


@dataclass
class ColoTensorSpec:
    """ ColoTensorSpec

    A data class for specifications of the `ColoTensor`.
    It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
    The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
@@ -1,7 +1,8 @@
import torch
from typing import Dict, Iterator, List, Tuple, Union

from typing import Iterator, Tuple, Union
import torch
import torch.nn as nn

from colossalai.tensor.colo_tensor import ColoTensor
@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):

    We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
    Therefore, all gather operation just remove the last element in shard list,
    e.g.:
        all-gather(S01) -> S0

    Argument:

@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
    and simulate the influence of the DimSpec.

    We BANNED all representations which shard_list in decreasing order,
    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
    Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
        and the second element decribes which logical axis will be sharded in that dimension.
    e.g.:
        all-to-all(S0, S1) -> [S01, R]
        all-to-all(S0, R) -> [R, S0]
    Otherwise, we extend the front shard_list to behind.
    e.g.:
        all-to-all(R, S1) -> [S1, R]

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
        and the second element decribes which logical axis will be sharded in that dimension.

@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
    and simulate the influence of the DimSpec.

    We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
    In addition, We BANNED all representations which shard_list in decreasing order,
    such as S10, so shard(S0) -> S10 is NOT allowed.
    Therefore, for the R dimension, we could just append any legal sharding dim on it.
    e.g.:
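The whitespace-only edits above touch the simulator docstrings; their rules are easiest to see as code. A self-contained sketch of the documented behaviour (these are re-implementations of the stated rules, not the library's all_gather_simulator / shard_simulator):

def all_gather_rule(shard_list):
    # all-gather drops the last mesh axis from a shard list: S01 -> S0
    return shard_list[:-1]

def shard_rule(shard_list, legal_sharding_dims):
    # shard may append a legal mesh axis, but only in increasing order:
    # S0 -> S01 is fine, S10 is banned.
    return [shard_list + [dim] for dim in legal_sharding_dims
            if not shard_list or dim > shard_list[-1]]

assert all_gather_rule([0, 1]) == [0]        # all-gather(S01) -> S0
assert shard_rule([0], [0, 1]) == [[0, 1]]   # shard(S0) -> S01 only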
@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)


def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    '''
    This method is used to convert the negative dim value to positive.
    '''
    dims_to_convert = []
    for dim, mesh_list in dim_partition_dict.items():
        if dim < 0:
            dims_to_convert.append(dim)
    for dim in dims_to_convert:
        dim_partition_dict.pop(dim)
        dim_partition_dict[dim_size + dim] = mesh_list
    return dim_partition_dict


def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    '''
    This method is used to merge the different key value which points to same physical position.

    For example:
        dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
        In this method, above dim_partition_dict will be converted to {1: [0, 1]}
    '''
    converted_dim_partition_dict = {}
    for dim, mesh_list in dim_partition_dict.items():
        if dim < 0:
            dim = dim_size + dim
        if dim not in converted_dim_partition_dict:
            converted_dim_partition_dict[dim] = mesh_list
        else:
            converted_dim_partition_dict[dim].extend(mesh_list)

    return converted_dim_partition_dict
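A quick usage sketch for the two new helpers; the expected outputs follow directly from the implementations above:

from colossalai.tensor.utils import convert_dim_partition_dict, merge_same_dim_mesh_list

# Negative keys are rewritten relative to the tensor rank.
print(convert_dim_partition_dict(dim_size=2, dim_partition_dict={-1: [1]}))
# -> {1: [1]}

# Keys 1 and -1 of a 2D tensor address the same dim, so their mesh lists are merged.
print(merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]}))
# -> {1: [0, 1]}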