
[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
Frank Lee committed via GitHub, 2 years ago
commit eee84908d4
36 changed files (lines changed in parentheses):
  1. colossalai/auto_parallel/tensor_shard/deprecated/_utils.py (20)
  2. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py (19)
  3. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py (32)
  4. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py (27)
  5. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py (33)
  6. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py (19)
  7. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py (14)
  8. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py (17)
  9. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py (22)
  10. colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py (12)
  11. colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py (183)
  12. colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py (8)
  13. colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py (45)
  14. colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py (10)
  15. colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py (9)
  16. colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py (22)
  17. colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py (11)
  18. colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py (8)
  19. colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py (8)
  20. colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py (3)
  21. colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py (31)
  22. colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py (8)
  23. colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py (8)
  24. colossalai/auto_parallel/tensor_shard/utils/__init__.py (4)
  25. colossalai/auto_parallel/tensor_shard/utils/misc.py (19)
  26. colossalai/tensor/sharding_spec.py (59)
  27. tests/test_auto_parallel/__init__.py (0)
  28. tests/test_auto_parallel/test_tensor_shard/__init__.py (0)
  29. tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py (10)
  30. tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py (8)
  31. tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py (0)
  32. tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py (37)
  33. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py (28)
  34. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py (1)
  35. tests/test_tensor/test_sharded_linear.py (18)
  36. tests/test_tensor/test_sharding_spec.py (5)

20
colossalai/auto_parallel/tensor_shard/deprecated/_utils.py

@ -1,13 +1,15 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from .constants import INFINITY_COST
@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node],
return resharding_costs
def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper which swallows exceptions raised by illegal sharding strategies and returns None instead.
"""

19
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py

@ -1,9 +1,12 @@
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['BatchNormHandler']
@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
@exception_handler
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'
@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler):
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

32
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py

@ -1,14 +1,18 @@
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
from .operator_handler import OperatorHandler
__all__ = ['BcastOpHandler']
@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler):
return output_sharding_spec_list
@exception_handler
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler):
##############################################
#used to generate strategies for torch.matmul#
##############################################
@exception_handler
@ignore_sharding_exception
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# matrix dimensions are fully replicated, so it do not need extra process.
@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler):
self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list)
@exception_handler
@ignore_sharding_exception
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)

27
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py

@ -1,10 +1,13 @@
import operator
from functools import reduce
import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['ConvHandler']
@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
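The strategy names used throughout these handlers are plain f-strings over the device-mesh axes: S<i> marks a tensor dimension sharded along mesh axis i (S01 shards along both axes flattened into one), R marks a replicated dimension, and each name reads output = input x weight. A quick check of how the names above expand for mesh_dim_0 = 0 and mesh_dim_1 = 1 (plain Python, no colossalai dependency):

mesh_dim_0, mesh_dim_1 = 0, 1

names = [
    f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}',    # split_input_batch_weight_out_channel
    f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R',    # split_input_both_dim_weight_in_channel
    f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R',    # split_1d_parallel_on_in_channel
]
print(names)
# ['S0S1 = S0R x RS1', 'S0R = S0S1 x S1R', 'RR = RS01 x S01R']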

33
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py

@ -1,15 +1,18 @@
import operator
from enum import Enum
from functools import reduce
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
from .operator_handler import OperatorHandler
from .strategy_generator import IntermediateStrategy, StrategyGenerator
__all__ = ['DotHandler']
@ -415,7 +418,7 @@ class DotHandler(OperatorHandler):
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost
@exception_handler
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -456,7 +459,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -496,7 +499,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -534,7 +537,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
@ -569,7 +572,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
@ -605,7 +608,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -641,7 +644,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
@ -678,7 +681,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

19
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py

@ -1,14 +1,17 @@
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from .operator_handler import OperatorHandler
__all__ = ['EmbeddingHandler']
@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler
@ignore_sharding_exception
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'

14
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py

@ -1,9 +1,13 @@
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size, ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
__all__ = ['LayerNormHandler']
@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler
@ignore_sharding_exception
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'

17
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py

@ -1,14 +1,17 @@
import colorsys
from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
import warnings
from copy import deepcopy
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler
class ReshapeHandler(OperatorHandler):
@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
@ignore_sharding_exception
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]

22
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py

@ -1,16 +1,20 @@
import math
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from .operator_handler import OperatorHandler
__all__ = ['UnaryElementwiseHandler']
@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
@ignore_sharding_exception
def register_strategy(self):
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function

12
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py

@ -6,12 +6,10 @@ from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
exception_handler,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler):
return output_sharding_spec_list
@exception_handler
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)

183
colossalai/auto_parallel/tensor_shard/node_handler/dot_handler.py

@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
weight_name: str) -> ShardingStrategy:
"""
This helper is shared by the module node handler and the function node handler. It converts the sharding
spec of the transposed weight to the correct partition spec.
Args:
strategy (ShardingStrategy): the strategy generated by the strategy generator.
weight_name (str): the name of the OperationData object for the weight.
"""
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape != op_data.data.shape, \
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
switch_partition_dim(sharding_spec, 0, -1)
return strategy
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
Args:
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
input_name (str): the name of the OperationData object for the input.
output_name (str): the name of the OperationData object for the output.
"""
# the result will be a list of strategies
sharding_strategies = []
# get operation data
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
# get logger for debug message
logger = get_dist_logger()
# for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only
# 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension.
# the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape.
# Thus, we enumerate to get all possible cases.
if 0 in input_sharding_spec.dim_partition_dict:
# if 0 is in the dim_partition_dict, it means that
# the generated sharding strategy does shard the non-matrix dimension,
# in this case, we need to do enumeration
num_input_dims = input_op_data.data.dim()
for i in range(num_input_dims - 1):
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Error occurred when converting the logical sharding spec to the physical one. Error details: {e}'
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
# in this case, we don't need to do enumeration
# but instead, we still need to convert the logical shape to physical shape
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
physical_shape=output_op_data.data.shape,
inplace=True)
print(input_op_data.data.shape)
print(output_op_data.data.shape)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
"""
@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler):
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
"""
Convert the sharding spec from the logical shape to the physical shape.
Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed:
1. the sharding spec is updated for the transposed weight
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
switch_partition_dim(sharding_spec, 0, -1)
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
# create multiple sharding strategies for the inputs
# as input can be multi-dimensinal and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
sharding_strategies = []
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
output_op_data = strategy.get_op_data_by_name(str(self.node))
num_input_dims = input_op_data.data.dim()
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
if 0 in input_sharding_spec.dim_partition_dict:
for i in range(num_input_dims - 1):
new_strategy = strategy.clone()
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
try:
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(new_strategy)
except ShardingException:
pass
else:
sharding_strategies.append(strategy)
return sharding_strategies
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
input_name=str(self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(F.linear)
@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler):
return mapping
def post_process(self, strategy: ShardingStrategy):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == str(self.node.args[1]):
assert op_data.logical_shape != op_data.data.shape
switch_partition_dim(sharding_spec, 0, -1)
# switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
weight_name=str(self.node.args[1]))
# create multiple sharding strategies for the inputs
# as input can be multi-dimensinal and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
sharding_strategies = []
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
output_op_data = strategy.get_op_data_by_name(str(self.node))
num_input_dims = input_op_data.data.dim()
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
if 0 in input_sharding_spec.dim_partition_dict:
for i in range(num_input_dims - 1):
new_strategy = strategy.clone()
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
try:
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(new_strategy)
except ShardingException:
pass
else:
sharding_strategies.append(strategy)
return strategy
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
input_name=str(self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(torch.bmm)
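The enumeration performed by _convert_logical_sharding_to_physical_sharding_spec_for_linear can be summarised in a few lines: try every physical input dimension as the target of the logical dim 0, keep the candidates that survive the legality check, and fall back to a single unchanged strategy when the non-matrix dimension is not sharded. Below is a minimal stand-alone sketch of that idea, using plain dicts instead of ShardingSpec and a divisibility test in place of catching ShardingNotDivisibleError (both simplifications are assumptions, not the library API):

from typing import Dict, List

def enumerate_physical_partitions(logical_partition: Dict[int, List[int]],
                                  physical_shape: List[int],
                                  mesh_shape: List[int]) -> List[Dict[int, List[int]]]:
    # hypothetical helper for illustration; maps a 2D logical partition
    # {0: mesh_axes, -1: mesh_axes} onto an N-dimensional physical shape
    if 0 not in logical_partition:
        # the non-matrix dimension is not sharded: keep the spec as-is
        return [dict(logical_partition)]
    mesh_axes = logical_partition[0]
    num_devices = 1
    for axis in mesh_axes:
        num_devices *= mesh_shape[axis]
    candidates = []
    for i in range(len(physical_shape) - 1):    # the last dim is the matrix dim
        if physical_shape[i] % num_devices != 0:
            continue    # illegal mapping, analogous to ShardingNotDivisibleError
        candidate = {i: mesh_axes}
        if -1 in logical_partition:
            candidate[len(physical_shape) - 1] = logical_partition[-1]
        candidates.append(candidate)
    return candidates

# a [4, 6, 8] input on a 2 x 2 mesh with logical spec S0R: dim 0 or dim 1 can host the shard
print(enumerate_physical_partitions({0: [0]}, [4, 6, 8], [2, 2]))    # [{0: [0]}, {1: [0]}]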

8
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py

@ -1,6 +1,7 @@
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
'''
@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# S01R = S01R x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch_1d(0, 1))
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

45
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py

@ -5,7 +5,8 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator):
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy):
'''
@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@exception_handler
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_mapping = {
@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
strategies.append(self.split_input_batch_weight_out_channel(0, 1))
@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
rm_list = [strategy for strategy in strategies if strategy is None]
for rm_element in rm_list:
strategies.remove(rm_element)
illegal_strategy_list = []
# update mete info on cost
for strategy in strategies:
try:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
except AssertionError as e:
illegal_strategy_list.append(strategy)
warnings.warn(f'{e}')
for strategy in illegal_strategy_list:
strategies.remove(strategy)
return strategies

10
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py

@ -1,4 +1,5 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
Deal with case 1 and 2.
'''
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for strategy in self.predecessor_node.strategies_vector:
dim_partition_dict_mapping = {}
@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
Deal with case 3.
'''
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
index = self.op_data["index"].data
@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

9
colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py

@ -1,6 +1,7 @@
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''
@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator):
# RR = RR x R
strategy_list.append(self.non_split())
# update mete info on cost
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

22
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py

@ -3,6 +3,8 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# update mete info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec
@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mappingp['bias'] = bias_comm_spec
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:

11
colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py

@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
return dim_partition_list
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

8
colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py

@ -1,3 +1,5 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import OutputStrategyGenerator
@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}
@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy]

8
colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py

@ -1,3 +1,5 @@
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import StrategyGenerator
@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}
@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy]

3
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py

@ -1,4 +1,5 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in

31
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py

@ -4,13 +4,12 @@ from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx import Node
class StrategyGenerator(ABC):
@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
# validate whether the operation data is of the desired value
self.validate()
@property
def has_bias(self):
"""
@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = operand.data.dtype
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
@abstractmethod
def generate(self) -> List[ShardingStrategy]:
"""
Generate all possible sharding strategies for this operation.
"""
strategies = self.collate_strategies()
# some strategies may be None as ignore_sharding_exception may return None
# when ShardingSpecException occurs.
# thus, remove those None values
strategies = [strategy for strategy in strategies if strategy]
# update the costs
# update meta info on cost
# these update methods are all in-place, the default method will do nothing
# the cost info will only be added if the child class overrides these methods
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
@abstractmethod
def collate_strategies(self) -> List[ShardingStrategy]:
pass
@abstractmethod
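The net effect of this change is a template method: generate() in the base class owns the None filtering and the in-place cost updates, while each concrete generator only implements collate_strategies(). A condensed skeleton of that split, using strings as stand-ins for ShardingStrategy so the example stays self-contained (an assumption, not the real signatures):

from abc import ABC, abstractmethod
from typing import List, Optional

class GeneratorSkeleton(ABC):
    # sketch of the generate()/collate_strategies() split in StrategyGenerator

    def generate(self) -> List[str]:
        strategies = self.collate_strategies()
        # entries may be None where an illegal sharding strategy was ignored
        strategies = [s for s in strategies if s is not None]
        # cost updates are in-place; the base-class defaults do nothing,
        # so cost info only appears when a generator overrides these hooks
        for s in strategies:
            self.update_communication_cost(s)
            self.update_compute_cost(s)
            self.update_memory_cost(s)
        return strategies

    @abstractmethod
    def collate_strategies(self) -> List[Optional[str]]:
        ...

    def update_communication_cost(self, strategy: str) -> None: ...
    def update_compute_cost(self, strategy: str) -> None: ...
    def update_memory_cost(self, strategy: str) -> None: ...

class DummyGenerator(GeneratorSkeleton):

    def collate_strategies(self) -> List[Optional[str]]:
        return ['S0R = S0R x RR', None, 'RR = RR x RR']

print(DummyGenerator().generate())    # ['S0R = S0R x RR', 'RR = RR x RR']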

8
colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py

@ -1,4 +1,5 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

8
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py

@ -1,4 +1,5 @@
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategy for a where node, and record all strategies into the strategies_vector.
'''
@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list

4
colossalai/auto_parallel/tensor_shard/utils/__init__.py

@ -1,12 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .misc import ignore_sharding_exception
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
switch_partition_dim, update_partition_dim)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
'generate_sharding_size'
]

19
colossalai/auto_parallel/tensor_shard/utils/misc.py

@ -1,16 +1,19 @@
import functools
import warnings
__all__ = ['exception_handler']
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpecException
__all__ = ['ignore_sharding_exception']
def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper to handle the AssertionError in the function.
A function wrapper to handle the ShardingSpecException in the function.
If ShardingSpecException occurs, this function will return None.
Usage:
# mute the sharding spec exception raised inside the function
@exception_handler
@ignore_sharding_exception
def do_something():
...
"""
@ -18,9 +21,11 @@ def exception_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
logger = get_dist_logger()
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
except ShardingSpecException as e:
logger.debug(e)
return None
return wrapper
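The decorator now logs through the distributed logger at debug level and returns None on a ShardingSpecException, instead of warning on a bare AssertionError. A self-contained sketch of the same pattern follows; it re-declares the exception and uses the standard logging module so it runs without an initialized ColossalAI runtime, and the builder function below is purely illustrative:

import functools
import logging

class ShardingSpecException(Exception):
    """Stand-in for colossalai.tensor.sharding_spec.ShardingSpecException."""

def ignore_sharding_exception(func):
    """Return None instead of propagating ShardingSpecException (mirrors utils/misc.py)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingSpecException as e:
            logging.getLogger(__name__).debug(e)
            return None

    return wrapper

@ignore_sharding_exception
def build_strategy(dim_size: int, num_devices: int) -> str:
    # hypothetical strategy builder used only to demonstrate the decorator
    if dim_size % num_devices != 0:
        raise ShardingSpecException(f'{dim_size} is not divisible by {num_devices}')
    return f'shard {dim_size} over {num_devices} devices'

print(build_strategy(16, 4))    # 'shard 16 over 4 devices'
print(build_strategy(6, 4))     # None: the exception is swallowed and logged at debug level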

59
colossalai/tensor/sharding_spec.py

@ -1,10 +1,12 @@
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
@ -138,7 +140,19 @@ class _DimSpec:
return difference
class ShardingException(Exception):
class ShardingSpecException(Exception):
pass
class ShardingOutOfIndexError(ShardingSpecException):
pass
class DuplicatedShardingDimensionError(ShardingSpecException):
pass
class ShardingNotDivisibleError(ShardingSpecException):
pass
@ -156,7 +170,11 @@ class ShardingSpec:
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
def __init__(self,
device_mesh: DeviceMesh,
entire_shape: torch.Size,
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
@ -174,19 +192,36 @@ class ShardingSpec:
return ' '.join(res_list)
def _sanity_check(self):
'''
In sanity check, we need make sure all axes in logical device mesh only be used
once.
'''
dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
# make sure all axes in the logical device mesh are used only once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in self.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise ValueError(
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
# make sure that the dimension is not out of index
for dim in self.dim_partition_dict.keys():
if dim >= len(self.entire_shape):
raise ShardingOutOfIndexError(
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
)
# make sure that the size of each sharded dimension is divisible by the number of devices it is sharded over
for dim, shard_list in self.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f'The size of dimension at index {dim} is {tensor_dim_size}, so it cannot be evenly sharded over {num_devices} devices.'
)
def convert_dict_to_shard_sequence(self):
'''
Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
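The strengthened _sanity_check now rejects three kinds of illegal specs: a device-mesh axis used more than once, a sharded dimension index beyond the tensor rank, and a dimension size not divisible by the product of the mesh sizes it is sharded over. A small sketch of how the divisibility case surfaces, assuming (as in the existing constructor) that _sanity_check runs when a ShardingSpec is built; it requires colossalai to be installed:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError, ShardingSpec

# 2 x 2 logical mesh, as in the tests below
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# legal: dimension 0 of size 16 is sharded over mesh axes 0 and 1, i.e. 2 * 2 = 4 devices
ShardingSpec(device_mesh, torch.Size((16, 8, 6)), dim_partition_dict={0: [0, 1]})

# illegal: 6 is not divisible by 4, so the spec is rejected up front
# instead of surfacing as a generic AssertionError later
try:
    ShardingSpec(device_mesh, torch.Size((6, 8, 6)), dim_partition_dict={0: [0, 1]})
except ShardingNotDivisibleError as e:
    print(e)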

0
tests/test_auto_parallel/__init__.py

0
tests/test_auto_parallel/test_tensor_shard/__init__.py

10
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py

@ -1,12 +1,15 @@
from cProfile import run
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
@ -27,6 +30,7 @@ class ConvModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)

8
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py

@ -1,12 +1,13 @@
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MatmulModel(nn.Module):
@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)

0
tests/test_auto_parallel/test_tensor_shard/test_node_handler/__init__.py

37
tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py

@ -0,0 +1,37 @@
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
"""
This function checks whether the ShardingSpec is valid for the physical tensor.
This check includes 2 items:
1. the sharding spec covers all dimensions of the physical tensor
2. the size of each sharded dimension is divisible by the number of devices it is sharded over.
"""
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
assert sharding_len == tensor_num_dim, \
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
if str(dim_spec).startswith('S'):
devices_str = str(dim_spec).lstrip('S')
num_devices = 1
if '0' in devices_str:
num_devices *= num_devices_in_col
if '1' in devices_str:
num_devices *= num_devices_in_row
assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has size {dim_size}, which cannot be evenly sharded over {num_devices} devices.'
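A short usage sketch for the new helper, assuming it is run from the repository root so the tests package is importable; the 2 x 2 mesh mirrors the handler tests in this directory:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import is_sharding_spec_valid

# 2 x 2 mesh, matching the handler tests
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
tensor = torch.rand(16, 8)

# S01 on dimension 0: 16 is divisible by 2 * 2 devices, so the check passes silently
spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0, 1]})
is_sharding_spec_valid(spec, tensor)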

28
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
is_sharding_spec_valid
def test_linear_module_handler():
model = nn.Sequential(nn.Linear(16, 32).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
@ -91,6 +92,12 @@ def test_linear_module_handler():
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@ -101,7 +108,7 @@ def test_linear_module_handler():
def test_linear_function_handler():
model = nn.Linear(16, 32).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
@ -117,11 +124,13 @@ def test_linear_function_handler():
# check operation data mapping
mapping = handler.get_operation_data_mapping()
print(mapping['input'].logical_shape)
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16])
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 16])
assert mapping['input'].logical_shape == torch.Size([16, 16])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
@ -137,7 +146,7 @@ def test_linear_function_handler():
assert mapping['output'].name == "linear"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 32])
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@ -167,11 +176,18 @@ def test_linear_function_handler():
for strategy in strategies_vector:
strategy: ShardingStrategy
print(strategy)
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
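The assertions above reflect that the linear function handler reasons about a 2D logical view of its input: the leading batch dimensions of the physical (2, 2, 4, 16) tensor collapse into a single row dimension, which is why the logical shape checked earlier is (16, 16). A minimal illustration of that flattening in plain PyTorch, not the handler's internal code:

import torch

# physical input fed to the traced linear function in the test above
physical = torch.rand(2, 2, 4, 16)

# the handler works with a 2D logical view: all leading dims collapse into one
logical = physical.view(-1, physical.shape[-1])
assert logical.shape == torch.Size([16, 16])    # 2 * 2 * 4 = 16 rows, 16 features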

1
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py

@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \

18
tests/test_tensor/test_sharded_linear.py

@ -1,16 +1,18 @@
from functools import partial
from lib2to3 import pgen2
import colossalai
import torch
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def run_dist(rank, world_size, port):
@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# create mlp vars
x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()

5
tests/test_tensor/test_sharding_spec.py

@ -1,6 +1,7 @@
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
@ -11,7 +12,7 @@ def test_sharding_spec():
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8, 6))
entire_shape = torch.Size((16, 8, 6))
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
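With the new divisibility check, a spec that shards dimension 0 over mesh axes 0 and 1 requires that dimension's size to be divisible by 2 * 2 = 4 devices. The arithmetic the check performs for the enlarged (16, 8, 6) shape is sketched below in plain Python, with no ColossalAI imports:

mesh_shape = (2, 2)
shard_axes = [0, 1]                  # dim_partition_dict = {0: [0, 1]}
num_devices = 1
for axis in shard_axes:
    num_devices *= mesh_shape[axis]  # 2 * 2 = 4 devices along dimension 0

assert 16 % num_devices == 0         # dimension 0 of the test shape (16, 8, 6) passes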
