[autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key

* fix import error

* fix matmul test issue

* fix unit test issue
YuliangLiu0306 committed
parent 4268ae017b
commit 49216d7ab1
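
For context on the bug class being fixed: a `dim_partition_dict` maps tensor dimensions to the device-mesh dimensions they are sharded over, and a negative key such as `-1` names the same physical dimension as its positive counterpart. A minimal sketch of the aliasing problem (plain Python, illustrative names only, not ColossalAI APIs):

```python
# For a 2D tensor, dim 1 and dim -1 are the same physical dimension,
# but as dict keys they are distinct, so the sharding info gets split.
dim_partition_dict = {1: [0], -1: [1]}

# Reading the sharding of the last dimension by either key misses half of it.
print(dim_partition_dict[1])     # [0]
print(dim_partition_dict[-1])    # [1]

# The fix normalizes keys to non-negative indices and merges duplicates,
# so the last dimension is correctly sharded over mesh dims [0, 1].
dim_size = 2
normalized = {}
for dim, mesh_list in dim_partition_dict.items():
    dim = dim % dim_size                      # -1 -> 1
    normalized.setdefault(dim, []).extend(mesh_list)
print(normalized)                             # {1: [0, 1]}
```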

@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard
if 1 in dim_partition_dict:
shard = dim_partition_dict.pop(1)
dim_partition_dict[0] = shard
# re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh,

@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_mapping = {
@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_mapping = {
@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_mapping = {
@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_mapping = {
@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_mapping = {
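
The decorator added throughout this file is imported but not defined in this diff. A purely hypothetical sketch of the pattern its name and usage suggest, i.e. a strategy method that fails with a sharding error is skipped rather than aborting enumeration; this is an illustration, not ColossalAI's actual implementation:

```python
import functools


# Hypothetical stand-in for ignore_sharding_exception: if building a strategy
# raises, return None so the caller can filter out the failed strategy. The
# real decorator presumably targets a narrower, sharding-specific exception.
def ignore_sharding_exception_sketch(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            return None

    return wrapper
```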

@ -3,9 +3,12 @@ import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}

@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict
class StrategyGenerator(ABC):
@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
op_data = self.op_data[op_data_name]
if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
sharding_spec = []
for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
dim_size = len(logical_shape)
dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=output.shape,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict_element)
else:
dim_size = len(op_data.logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
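
A small sketch of why the conversion against `logical_shape` matters here, using the `convert_dim_partition_dict` helper added later in this diff (the printed value follows directly from that implementation; running it assumes this commit is installed):

```python
from colossalai.tensor.utils import convert_dim_partition_dict

logical_shape = (8, 4, 16, 32)
dim_size = len(logical_shape)

# The negative key -1 is rewritten to 3 before the ShardingSpec is built,
# so it can no longer alias dim 3 under a second key.
dim_partition_dict = {0: [0], -1: [1]}
print(convert_dim_partition_dict(dim_size, dim_partition_dict))
# {0: [0], 3: [1]}
```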

@ -1,19 +1,17 @@
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from . import distspec
from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .comm_spec import CollectiveCommPattern, CommSpec
from . import distspec
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
]
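
With the helpers re-exported here, they can be imported from the package root as well as from `colossalai.tensor.utils` (a usage note, assuming this commit is checked out):

```python
from colossalai.tensor import convert_dim_partition_dict, merge_same_dim_mesh_list
```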

@ -1,11 +1,11 @@
import torch
from typing import Optional
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
def filter_args(func, *args):

@ -4,9 +4,10 @@ from typing import Callable, Optional, Set
import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from colossalai.tensor.tensor_spec import ColoTensorSpec
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS

@ -1,12 +1,14 @@
from colossalai.tensor.distspec import _DistSpec
# from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager
import torch
import torch.distributed as dist
# from colossalai.nn.layer.utils import divide
from numpy import prod
from packaging import version
from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.process_group import ProcessGroup
# TODO(jiaruifang) circular import, move the divide to colossalai.commons.

@ -1,9 +1,11 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.tensor_spec import ColoTensorSpec
class ParamOpHook(ABC):

@ -6,6 +6,8 @@ import torch
from colossalai.device.device_mesh import DeviceMesh
from .utils import merge_same_dim_mesh_list
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
ALLGATHER_COST = 20
@ -181,8 +183,12 @@ class ShardingSpec:
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
dim_partition_dict=self.dim_partition_dict)
self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
self.convert_shard_sequence_to_dict()
self._sanity_check()
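
The effect of the merge step added to `__init__`, with values taken from the `merge_same_dim_mesh_list` docstring later in this diff (again assuming this commit is installed):

```python
from colossalai.tensor.utils import merge_same_dim_mesh_list

# For a 2D tensor, keys 1 and -1 name the same physical dimension, so their
# mesh-dim lists are merged before the sharding sequence is derived.
print(merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]}))
# {1: [0, 1]}
```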

@ -1,14 +1,16 @@
from dataclasses import dataclass
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
@dataclass
class ColoTensorSpec:
""" ColoTensorSpec
A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional. If not set, their default values are `Replicate()` and `None`.

@ -1,7 +1,8 @@
import torch
from typing import Dict, Iterator, List, Tuple, Union
from typing import Dict, Iterator, List, Tuple, Union
import torch
import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor
@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
Therefore, all gather operation just remove the last element in shard list,
e.g.:
e.g.:
all-gather(S01) -> S0
Argument:
@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
and simulate the influence of the DimSpec.
We BANNED all representations which shard_list in decreasing order,
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element describes which logical axis will be sharded in that dimension.
e.g.:
e.g.:
all-to-all(S0, S1) -> [S01, R]
all-to-all(S0, R) -> [R, S0]
Otherwise, we extend the front shard_list to behind.
e.g.:
e.g.:
all-to-all(R, S1) -> [S1, R]
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element describes which logical axis will be sharded in that dimension.
@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
and simulate the influence of the DimSpec.
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
In addition, We BANNED all representations which shard_list in decreasing order,
In addition, We BANNED all representations which shard_list in decreasing order,
such as S10, so shard(S0) -> S10 is NOT allowed.
Therefore, for the R dimension, we could just append any legal sharding dim on it.
e.g.:
@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
# Now we can set the attribute appropriately.
setattr(module, param_name, st)
def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
'''
This method converts negative dim keys in dim_partition_dict to their positive counterparts.
'''
dims_to_convert = []
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dims_to_convert.append(dim)
for dim in dims_to_convert:
dim_partition_dict.pop(dim)
dim_partition_dict[dim_size + dim] = mesh_list
return dim_partition_dict
def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
'''
This method merges keys of dim_partition_dict that point to the same physical dimension.
For example, for a 2D tensor, dim 1 and dim -1 refer to the same physical dimension,
so the dim_partition_dict {1: [0], -1: [1]} will be converted to {1: [0, 1]}.
'''
converted_dim_partition_dict = {}
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dim = dim_size + dim
if dim not in converted_dim_partition_dict:
converted_dim_partition_dict[dim] = mesh_list
else:
converted_dim_partition_dict[dim].extend(mesh_list)
return converted_dim_partition_dict
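
One behavioral detail of the two helpers as written: `convert_dim_partition_dict` rewrites its argument in place and returns the same dict, while `merge_same_dim_mesh_list` builds a new dict (though it reuses, and extends, the caller's mesh lists). A short check of that behavior:

```python
from colossalai.tensor.utils import convert_dim_partition_dict, merge_same_dim_mesh_list

original = {-1: [1], 0: [0]}
converted = convert_dim_partition_dict(dim_size=4, dim_partition_dict=original)
print(converted)              # {0: [0], 3: [1]}
print(converted is original)  # True: the input dict was modified in place

mesh_list = [0]
merged = merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: mesh_list, -1: [1]})
print(merged)                 # {1: [0, 1]}
print(mesh_list)              # [0, 1]: the caller's list was extended in place
```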
