[autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key

* fix import error

* fix matmul test issue

* fix unit test issue
YuliangLiu0306 2022-11-08 17:03:50 +08:00 committed by GitHub
parent 4268ae017b
commit 49216d7ab1
12 changed files with 108 additions and 43 deletions

View File

@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
            if -1 in dim_partition_dict:
                shard = dim_partition_dict.pop(-1)
                dim_partition_dict[0] = shard
+            if 1 in dim_partition_dict:
+                shard = dim_partition_dict.pop(1)
+                dim_partition_dict[0] = shard
            # re-init the sharding spec
            input_sharding_spec.__init__(input_sharding_spec.device_mesh,
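The hunk above extends an existing normalization: when a 1D operand's partition is keyed by -1 (and now also by 1), the shard is remapped onto dim 0 before the sharding spec is re-initialized. A minimal sketch of the idea on plain dicts (the helper name is illustrative, not part of the codebase):

```python
# Hedged sketch: collapse the -1/1 keys of a 1D operand's partition dict onto
# dim 0, mirroring the logic added in the hunk above.
# dim_partition_dict maps a tensor dim -> list of device-mesh axes.
def normalize_1d_partition(dim_partition_dict):
    for key in (-1, 1):
        if key in dim_partition_dict:
            dim_partition_dict[0] = dim_partition_dict.pop(key)
    return dim_partition_dict


print(normalize_1d_partition({-1: [0]}))  # {0: [0]}
print(normalize_1d_partition({1: [1]}))   # {0: [1]}
```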

View File

@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import StrategyGenerator

@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

+    @ignore_sharding_exception
     def split_input_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
         dim_partition_dict_mapping = {

@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
         dim_partition_dict_mapping = {

@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def non_split(self):
         name = f'RR = RR x R'
         dim_partition_dict_mapping = {

@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {

@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {

@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

+    @ignore_sharding_exception
     def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
         dim_partition_dict_mapping = {
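The hunks above only prepend @ignore_sharding_exception to each strategy method, so a sharding combination that turns out to be invalid is skipped instead of aborting strategy generation. A rough sketch of what such a decorator typically does, with an assumed exception class and return convention rather than ColossalAI's actual implementation:

```python
import functools


class ShardingSpecError(Exception):
    """Stand-in for the sharding-related exception the real decorator catches."""


def ignore_sharding_exception_sketch(func):
    """Hedged sketch: run the strategy method and drop the strategy on sharding errors."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingSpecError:
            # The invalid strategy is simply discarded.
            return None

    return wrapper
```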

View File

@@ -3,9 +3,12 @@ import operator
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                         enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)

 from .strategy_generator import StrategyGenerator

@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

+    @ignore_sharding_exception
     def _generate_strategy_with_dim_partition(self, dim_partition):
         dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}

View File

@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.utils import convert_dim_partition_dict


 class StrategyGenerator(ABC):

@@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
         op_data = self.op_data[op_data_name]

         if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
             sharding_spec = []
-            for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
+            for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
+                dim_size = len(logical_shape)
+                dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
                 sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                             entire_shape=output.shape,
+                                             entire_shape=logical_shape,
                                              dim_partition_dict=dim_partition_dict_element)
         else:
+            dim_size = len(op_data.logical_shape)
+            dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
             sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                          entire_shape=op_data.logical_shape,
                                          dim_partition_dict=dim_partition_dict)
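This change is the heart of the fix: before a ShardingSpec is built, every key of dim_partition_dict is normalized against the logical shape, because a negative key and its positive counterpart address the same axis but are distinct dict keys. A small illustration with plain dicts, no ColossalAI objects involved:

```python
# -1 and 2 name the same axis of a rank-3 tensor, but as dict keys they differ,
# so a lookup on the positive dim silently misses the entry.
dim_partition_dict = {-1: [0]}
logical_shape = (4, 8, 16)

print(2 in dim_partition_dict)  # False, even though -1 refers to dim 2 here

# Normalizing keys the way convert_dim_partition_dict does removes the ambiguity.
normalized = {d + len(logical_shape) if d < 0 else d: mesh for d, mesh in dim_partition_dict.items()}
print(normalized)  # {2: [0]}
```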

View File

@@ -1,19 +1,17 @@
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
-from .distspec import ShardSpec
-from .distspec import ReplicaSpec
-from .compute_spec import ComputeSpec, ComputePattern
-from .colo_tensor import ColoTensor
-from .colo_parameter import ColoParameter
-from .utils import convert_parameter, named_params_with_colotensor
-from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, ParamOpHookManager
-from .comm_spec import CollectiveCommPattern, CommSpec
-from . import distspec
+from . import distspec
+from .colo_parameter import ColoParameter
+from .colo_tensor import ColoTensor
+from .comm_spec import CollectiveCommPattern, CommSpec
+from .compute_spec import ComputePattern, ComputeSpec
+from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
+from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .process_group import ProcessGroup
+from .tensor_spec import ColoTensorSpec
+from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
+    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
 ]

View File

@@ -1,11 +1,11 @@
-import torch
 from typing import Optional

+import torch
+
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.tensor_spec import ColoTensorSpec


 def filter_args(func, *args):

View File

@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set

 import torch

-from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.tensor.tensor_spec import ColoTensorSpec

 from .const import TensorType
 from .op_wrapper import _COLOSSAL_OPS

View File

@@ -1,12 +1,14 @@
-from colossalai.tensor.distspec import _DistSpec
-# from colossalai.nn.layer.utils import divide
-from numpy import prod
 from contextlib import contextmanager

 import torch
 import torch.distributed as dist
+# from colossalai.nn.layer.utils import divide
+from numpy import prod
 from packaging import version

 from colossalai.logging import get_dist_logger
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.process_group import ProcessGroup

 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.

View File

@@ -1,9 +1,11 @@
-import torch
-from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Any
+from contextlib import contextmanager
+from typing import Any, List, Tuple
+
+import torch

 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor import ColoTensorSpec
+from colossalai.tensor.tensor_spec import ColoTensorSpec


 class ParamOpHook(ABC):

View File

@@ -6,6 +6,8 @@ import torch

 from colossalai.device.device_mesh import DeviceMesh

+from .utils import merge_same_dim_mesh_list
+
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

 ALLGATHER_COST = 20

@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
+            assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
+            assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()
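In the constructor change above, merge_same_dim_mesh_list folds keys that alias the same axis into one entry before the dict is turned into a sharding sequence: for a 2D tensor, {1: [0], -1: [1]} becomes {1: [0, 1]}, i.e. the second dim is sharded over both mesh axes (S01 in the notation used elsewhere in this diff). A dict-level illustration of the merge, no DeviceMesh required:

```python
# Both keys point at dim 1 of a 2D tensor; merging yields a single entry
# sharded over mesh axes 0 and 1 (sharding sequence [R, S01]).
raw = {1: [0], -1: [1]}
dim_size = 2

merged = {}
for dim, mesh_list in raw.items():
    dim = dim + dim_size if dim < 0 else dim
    merged.setdefault(dim, []).extend(mesh_list)

print(merged)  # {1: [0, 1]}
```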

View File

@@ -1,14 +1,16 @@
-from typing import Optional
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from .compute_spec import ComputeSpec
-from colossalai.tensor import ProcessGroup
 from dataclasses import dataclass
+from typing import Optional
+
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+
+from .compute_spec import ComputeSpec


 @dataclass
 class ColoTensorSpec:
     """ ColoTensorSpec
     A data class for specifications of the `ColoTensor`.
     It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
     The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.

View File

@@ -1,7 +1,8 @@
-import torch
-from typing import Iterator, Tuple, Union
+from typing import Dict, Iterator, List, Tuple, Union
+
+import torch
 import torch.nn as nn

 from colossalai.tensor.colo_tensor import ColoTensor

@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
     We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
     Therefore, all gather operation just remove the last element in shard list,
     e.g.:
         all-gather(S01) -> S0

     Argument:

@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
     and simulate the influence of the DimSpec.
     We BANNED all representations which shard_list in decreasing order,
     such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
     Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
     Argument:
     target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
     and the second element decribes which logical axis will be sharded in that dimension.
     e.g.:
         all-to-all(S0, S1) -> [S01, R]
         all-to-all(S0, R) -> [R, S0]
     Otherwise, we extend the front shard_list to behind.
     e.g.:
         all-to-all(R, S1) -> [S1, R]

     Argument:
     target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
     and the second element decribes which logical axis will be sharded in that dimension.
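The docstring above specifies all_to_all_simulator purely through examples. A hedged reconstruction of that behaviour from the examples alone (not the library's actual code), useful for checking one's reading of the notation:

```python
from typing import List, Tuple


# Sketch reconstructed from the docstring's examples; the real
# all_to_all_simulator in colossalai.tensor.utils may differ in detail.
# Each pair is (tensor dim, list of mesh axes sharding that dim).
def all_to_all_sketch(f_pair: Tuple[int, List[int]], b_pair: Tuple[int, List[int]]) -> List[List[int]]:
    _, f_shard = f_pair
    _, b_shard = b_pair
    if b_shard:
        # Behind shard list is non-empty: extend it onto the front dimension.
        return [f_shard + b_shard, []]
    # Behind is replicated: push the front shard list to the back.
    return [[], f_shard]


print(all_to_all_sketch((0, [0]), (1, [1])))  # [[0, 1], []]  i.e. [S01, R]
print(all_to_all_sketch((0, [0]), (1, [])))   # [[], [0]]     i.e. [R, S0]
print(all_to_all_sketch((0, []), (1, [1])))   # [[1], []]     i.e. [S1, R]
```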
@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
     and simulate the influence of the DimSpec.
     We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
     In addition, We BANNED all representations which shard_list in decreasing order,
     such as S10, so shard(S0) -> S10 is NOT allowed.
     Therefore, for the R dimension, we could just append any legal sharding dim on it.
     e.g.:
@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to convert negative dim keys to positive ones.
+    '''
+    dims_to_convert = []
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dims_to_convert.append(dim)
+    for dim in dims_to_convert:
+        mesh_list = dim_partition_dict.pop(dim)
+        dim_partition_dict[dim_size + dim] = mesh_list
+    return dim_partition_dict
+
+
+def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to merge different keys which point to the same physical position.
+
+    For example:
+        dim_partition_dict = {1: [0], -1: [1]} for a 2d tensor: dim 1 and dim -1 point to the same physical position,
+        so the dict will be converted to {1: [0, 1]}.
+    '''
+    converted_dim_partition_dict = {}
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dim = dim_size + dim
+        if dim not in converted_dim_partition_dict:
+            converted_dim_partition_dict[dim] = mesh_list
+        else:
+            converted_dim_partition_dict[dim].extend(mesh_list)
+    return converted_dim_partition_dict
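A quick usage check of the two helpers added above (this assumes a ColossalAI build that includes this commit, where both are re-exported from colossalai.tensor as the __init__.py change shows):

```python
from colossalai.tensor import convert_dim_partition_dict, merge_same_dim_mesh_list

# Negative keys are rewritten against the tensor rank: -1 on a rank-4 tensor is dim 3.
print(convert_dim_partition_dict(4, {-1: [1]}))
# {3: [1]}

# Keys 0 and -3 address the same axis of a rank-3 tensor, so their mesh lists merge.
print(merge_same_dim_mesh_list(dim_size=3, dim_partition_dict={0: [0], -3: [1]}))
# {0: [0, 1]}
```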