From 49216d7ab18569db74654dee2828b79e67ff8a4a Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 8 Nov 2022 17:03:50 +0800
Subject: [PATCH] [autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key

* fix import error

* fix matmul test issue

* fix unit test issue
---
 .../node_handler/matmul_handler.py            |  3 ++
 .../strategy/batch_norm_generator.py          |  7 +++
 .../strategy/normal_pooling_generator.py      | 10 ++--
 .../strategy/strategy_generator.py            |  9 +++-
 colossalai/tensor/__init__.py                 | 20 ++++----
 colossalai/tensor/colo_parameter.py           |  6 +--
 colossalai/tensor/colo_tensor.py              |  5 +-
 colossalai/tensor/dist_spec_mgr.py            | 10 ++--
 colossalai/tensor/param_op_hook.py            | 10 ++--
 colossalai/tensor/sharding_spec.py            |  6 +++
 colossalai/tensor/tensor_spec.py              | 10 ++--
 colossalai/tensor/utils.py                    | 51 ++++++++++++++++---
 12 files changed, 106 insertions(+), 41 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index 5bc899049..ba3e03976 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
             if -1 in dim_partition_dict:
                 shard = dim_partition_dict.pop(-1)
                 dim_partition_dict[0] = shard
+            if 1 in dim_partition_dict:
+                shard = dim_partition_dict.pop(1)
+                dim_partition_dict[0] = shard
 
             # re-init the sharding spec
             input_sharding_spec.__init__(input_sharding_spec.device_mesh,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index b3769ccd6..6a81a7eaa 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def split_input_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
         dim_partition_dict_mapping = {
@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
         dim_partition_dict_mapping = {
@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def non_split(self):
         name = f'RR = RR x R'
         dim_partition_dict_mapping = {
@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {
@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {
@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
         dim_partition_dict_mapping = {
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
index 457f51450..9df6d2fbf 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
@@ -3,9 +3,12 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                         enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
 
 from .strategy_generator import StrategyGenerator
 
@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def _generate_strategy_with_dim_partition(self, dim_partition):
         dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index 096bda619..c0f7a33da 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.utils import convert_dim_partition_dict
 
 
 class StrategyGenerator(ABC):
@@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
         op_data = self.op_data[op_data_name]
         if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
             sharding_spec = []
-            for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
+            for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
+                dim_size = len(logical_shape)
+                dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
                 sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                             entire_shape=output.shape,
+                                             entire_shape=logical_shape,
                                              dim_partition_dict=dim_partition_dict_element)
         else:
+            dim_size = len(op_data.logical_shape)
+            dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
             sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                          entire_shape=op_data.logical_shape,
                                          dim_partition_dict=dim_partition_dict)
diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index 4946d7077..ebccf7e18 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -1,19 +1,17 @@
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
-from .distspec import ShardSpec
-from .distspec import ReplicaSpec
-
-from .compute_spec import ComputeSpec, ComputePattern
-from .colo_tensor import ColoTensor
+from . import distspec
 from .colo_parameter import ColoParameter
-from .utils import convert_parameter, named_params_with_colotensor
+from .colo_tensor import ColoTensor
+from .comm_spec import CollectiveCommPattern, CommSpec
+from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
 from .param_op_hook import ParamOpHook, ParamOpHookManager
-from .comm_spec import CollectiveCommPattern, CommSpec
-from . import distspec
+from .process_group import ProcessGroup
+from .tensor_spec import ColoTensorSpec
+from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
+    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
 ]
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 17c326516..7247ef966 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -1,11 +1,11 @@
-import torch
-
 from typing import Optional
 
+import torch
+
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 def filter_args(func, *args):
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 2dd0de560..c9e48a453 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set
 
 import torch
 
-from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 from .const import TensorType
 from .op_wrapper import _COLOSSAL_OPS
diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index f1dc241a8..d5c0ce28e 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -1,12 +1,14 @@
-from colossalai.tensor.distspec import _DistSpec
-# from colossalai.nn.layer.utils import divide
-from numpy import prod
 from contextlib import contextmanager
+
 import torch
 import torch.distributed as dist
+# from colossalai.nn.layer.utils import divide
+from numpy import prod
 from packaging import version
+
 from colossalai.logging import get_dist_logger
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
 
 
 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 03cb090a6..23fad971c 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -1,9 +1,11 @@
-import torch
-from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Any
+from contextlib import contextmanager
+from typing import Any, List, Tuple
+
+import torch
+
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor import ColoTensorSpec
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 class ParamOpHook(ABC):
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py
index c8bce731e..cdd033885 100644
--- a/colossalai/tensor/sharding_spec.py
+++ b/colossalai/tensor/sharding_spec.py
@@ -6,6 +6,8 @@ import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 
+from .utils import merge_same_dim_mesh_list
+
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
 
 ALLGATHER_COST = 20
@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
+            assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
+            assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()
 
diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py
index 23dd3b9af..580df9f8f 100644
--- a/colossalai/tensor/tensor_spec.py
+++ b/colossalai/tensor/tensor_spec.py
@@ -1,14 +1,16 @@
+from dataclasses import dataclass
 from typing import Optional
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
+
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+
 from .compute_spec import ComputeSpec
-from colossalai.tensor import ProcessGroup
-from dataclasses import dataclass
 
 
 @dataclass
 class ColoTensorSpec:
     """ ColoTensorSpec
-    
+
     A data class for specifications of the `ColoTensor`.
     It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
     The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py
index b2eda5a8d..c5ffc9fb5 100644
--- a/colossalai/tensor/utils.py
+++ b/colossalai/tensor/utils.py
@@ -1,7 +1,8 @@
-import torch
+from typing import Dict, Iterator, List, Tuple, Union
 
-from typing import Iterator, Tuple, Union
+import torch
 import torch.nn as nn
+
 from colossalai.tensor.colo_tensor import ColoTensor
 
 
@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
 
     We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
     Therefore, all gather operation just remove the last element in shard list,
-    e.g.: 
+    e.g.:
         all-gather(S01) -> S0
 
     Argument:
@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
     and simulate the influence of the DimSpec.
 
     We BANNED all representations which shard_list in decreasing order,
-    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. 
+    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
     Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.
-    e.g.: 
+    e.g.:
         all-to-all(S0, S1) -> [S01, R]
         all-to-all(S0, R) -> [R, S0]
     Otherwise, we extend the front shard_list to behind.
-    e.g.: 
+    e.g.:
         all-to-all(R, S1) -> [S1, R]
-    
+
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.
@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
     and simulate the influence of the DimSpec.
 
     We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
-    In addition, We BANNED all representations which shard_list in decreasing order, 
+    In addition, We BANNED all representations which shard_list in decreasing order,
     such as S10, so shard(S0) -> S10 is NOT allowed.
     Therefore, for the R dimension, we could just append any legal sharding dim on it.
     e.g.:
@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
 
     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to convert the negative dim value to positive.
+    '''
+    dims_to_convert = []
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dims_to_convert.append(dim)
+    for dim in dims_to_convert:
+        dim_partition_dict.pop(dim)
+        dim_partition_dict[dim_size + dim] = mesh_list
+    return dim_partition_dict
+
+
+def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to merge the different key value which points to same physical position.
+
+    For example:
+        dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
+        In this method, above dim_partition_dict will be converted to {1: [0, 1]}
+    '''
+    converted_dim_partition_dict = {}
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dim = dim_size + dim
+        if dim not in converted_dim_partition_dict:
+            converted_dim_partition_dict[dim] = mesh_list
+        else:
+            converted_dim_partition_dict[dim].extend(mesh_list)
+
+    return converted_dim_partition_dict
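
A minimal usage sketch (not part of the patch itself) of the two helpers added to colossalai/tensor/utils.py. It assumes a ColossalAI build that already contains this change, so the new re-exports from `colossalai.tensor` are importable; the example inputs are illustrative only.

    # Hypothetical standalone example: only the two helper functions introduced
    # by this patch are exercised, no DeviceMesh or ShardingSpec is constructed.
    from colossalai.tensor import convert_dim_partition_dict, merge_same_dim_mesh_list

    # Rewrite a negative dim key as its positive counterpart for a 4-d tensor:
    # dim -1 of a 4-d tensor is dim 3, so the key -1 becomes 3.
    print(convert_dim_partition_dict(dim_size=4, dim_partition_dict={0: [0], -1: [1]}))
    # -> {0: [0], 3: [1]}

    # Merge keys that address the same physical dim of a 2-d tensor:
    # dim 1 and dim -1 are the same position, so their mesh axis lists are merged.
    print(merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]}))
    # -> {1: [0, 1]}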