Browse Source

[autoparallel] refactored shape consistency to remove redundancy (#1591)

* [autoparallel] refactored shape consistency to remove redundancy

* polish code

* polish code

* polish code
pull/1587/head^2
Frank Lee 2 years ago committed by GitHub
parent
commit
27fe8af60c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 45
      colossalai/auto_parallel/solver/_utils.py
  2. 57
      colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py
  3. 73
      colossalai/auto_parallel/solver/op_handler/conv_handler.py
  4. 65
      colossalai/auto_parallel/solver/op_handler/dot_handler.py
  5. 71
      colossalai/auto_parallel/solver/op_handler/operator_handler.py
  6. 58
      colossalai/auto_parallel/solver/strategies_constructor.py
  7. 3
      colossalai/context/singleton_meta.py
  8. 39
      colossalai/tensor/shape_consistency.py
  9. 11
      tests/test_auto_parallel/test_batch_norm_handler.py
  10. 11
      tests/test_auto_parallel/test_conv_handler.py
  11. 6
      tests/test_auto_parallel/test_cost_graph.py
  12. 11
      tests/test_auto_parallel/test_dot_handler.py
  13. 4
      tests/test_auto_parallel/test_strategies_constructor.py

45
colossalai/auto_parallel/solver/_utils.py

@ -1,8 +1,9 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List
from typing import Union, Dict, List, Optional
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None):
'''
Compute the resharding costs with this specific strategy.
Argument:
nodes (List[Node]): a list of nodes
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
if count_backward:
# In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
input_spec, input_sharding_spec)
total_resharding_cost = resharding_cost_forward + resharding_cost_backward
else:
total_resharding_cost = resharding_cost_forward
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs

57
colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py

@ -4,7 +4,6 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec
__all__ = ['BatchNormHandler']
@ -115,15 +114,13 @@ class BatchNormHandler(OperatorHandler):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -156,8 +153,7 @@ class BatchNormHandler(OperatorHandler):
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
@ -192,15 +188,13 @@ class BatchNormHandler(OperatorHandler):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -234,15 +228,13 @@ class BatchNormHandler(OperatorHandler):
name = f'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -273,8 +265,7 @@ class BatchNormHandler(OperatorHandler):
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
@ -332,15 +323,13 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -374,15 +363,13 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -416,15 +403,13 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -459,7 +444,7 @@ class BatchNormHandler(OperatorHandler):
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:

73
colossalai/auto_parallel/solver/op_handler/conv_handler.py

@ -4,7 +4,6 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec
__all__ = ['ConvHandler']
@ -109,15 +108,13 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -158,15 +155,13 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -205,15 +200,13 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -252,15 +245,13 @@ class ConvHandler(OperatorHandler):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -296,15 +287,13 @@ class ConvHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -340,15 +329,13 @@ class ConvHandler(OperatorHandler):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -384,15 +371,13 @@ class ConvHandler(OperatorHandler):
name = f'RR = RR x RR'
dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -426,15 +411,13 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -475,15 +458,13 @@ class ConvHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

65
colossalai/auto_parallel/solver/op_handler/dot_handler.py

@ -3,7 +3,6 @@ import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from functools import reduce
from .._utils import generate_sharding_spec
__all__ = ['DotHandler']
@ -29,16 +28,14 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# linear layer weight is transposed during init
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -69,17 +66,15 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -106,15 +101,13 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -141,15 +134,13 @@ class DotHandler(OperatorHandler):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict_for_input = {1: [mesh_dim]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -176,15 +167,13 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -211,15 +200,13 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -246,15 +233,13 @@ class DotHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -281,15 +266,13 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

71
colossalai/auto_parallel/solver/op_handler/operator_handler.py

@ -7,6 +7,7 @@ from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .._utils import generate_resharding_costs, generate_sharding_spec
from ..sharding_strategy import StrategiesVector
@ -17,24 +18,24 @@ class OperatorHandler(ABC):
'''
The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
Argument:
input_node(Node): the input node in node argument list.
input_index(int): the index of input node in the node argument list.
weight(torch.Tensor): Weight of the node.
output_node(Node): Output_node is the output of the node.
device_mesh(DeviceMesh): A logical view of a physical mesh.
strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
'''
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
shape_consistency_manager: ShapeConsistencyManager):
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
handle_backward: bool = True):
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shape_consistency_manager = shape_consistency_manager
self.handle_backward = handle_backward
# find the module and its parameters associated with this node
# this can be used to compute the compute/communication/sharding cost
@ -102,35 +103,23 @@ class OperatorHandler(ABC):
return total_memory_cost, activation_memory_cost, weight_memory_cost
def _generate_resharding_costs(self, sharding_spec_for_input):
'''
Compute the resharding costs with this specific strategy.
Note: The resharding_cost of weight is NOT counted.
Argument:
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
'''
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
dtype = self.node._meta_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
input_spec, input_sharding_spec)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
nodes = self.predecessor_node
return generate_resharding_costs(nodes=nodes,
sharding_specs=sharding_specs,
count_backward=self.handle_backward,
dtype=dtype)
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
return generate_sharding_spec(input_=input_,
device_mesh=self.device_mesh,
dim_partition_dict=dim_partition_dict)
@abstractmethod
def _generate_compute_cost(self, *args, **kwargs):
"""
Compute the flops involved in the node.
"""
pass

58
colossalai/auto_parallel/solver/strategies_constructor.py

@ -11,7 +11,7 @@ import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec
from ._utils import generate_sharding_spec, generate_resharding_costs
class StrategiesConstructor:
@ -21,12 +21,10 @@ class StrategiesConstructor:
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
solver_options: SolverOptions):
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
@ -34,27 +32,8 @@ class StrategiesConstructor:
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.shape_consistency_manager = shape_consistency_manager
self.solver_options = solver_options
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
'''
Compute the resharding costs with this specific strategy.
Argument:
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
'''
resharding_costs = {}
for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_sharding_spec)
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
@ -120,14 +99,13 @@ class StrategiesConstructor:
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager)
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
dot_handler.register_strategy()
# element-wise module
@ -158,8 +136,8 @@ class StrategiesConstructor:
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [
@ -214,8 +192,8 @@ class StrategiesConstructor:
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
@ -275,8 +253,8 @@ class StrategiesConstructor:
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [
@ -317,8 +295,8 @@ class StrategiesConstructor:
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
@ -335,8 +313,8 @@ class StrategiesConstructor:
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
@ -360,8 +338,8 @@ class StrategiesConstructor:
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
@ -397,8 +375,8 @@ class StrategiesConstructor:
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version

3
colossalai/context/singleton_meta.py

@ -15,4 +15,7 @@ class SingletonMeta(type):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
assert len(args) == 0 and len(
kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
return cls._instances[cls]

39
colossalai/tensor/shape_consistency.py

@ -1,15 +1,22 @@
import torch
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
from colossalai.context.singleton_meta import SingletonMeta
import torch.distributed as dist
import math
from functools import reduce
import operator
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
'set_shape_consistency_options'
]
class CollectiveCommPattern(Enum):
ALLGATHER = 'all_gather'
@ -152,14 +159,40 @@ class CommSpec:
tensor.data = tensor
class ShapeConsistencyManager:
@dataclass
class ShapeConsistencyOptions:
"""
ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
"""
# TODO: shape consistency option is not implemented yet
pass
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.
"""
manager = ShapeConsistencyManager()
manager.options = options
def __init__(self, consistency_option=None):
self.consistency_option = consistency_option
class ShapeConsistencyManager(metaclass=SingletonMeta):
def __init__(self):
self._options = None
self.total_communication_cost = 0
self.total_transform_steps = 0
self.cached_spec_pairs_transform_path = {}
@property
def options(self):
return self._options
@options.setter
def options(self, options_: ShapeConsistencyOptions):
assert isinstance(options_, ShapeConsistencyOptions)
self._options = options_
def get_all_all_gather_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single all-gather operation, and

11
tests/test_auto_parallel/test_batch_norm_handler.py

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@ -31,7 +30,6 @@ def test_bn_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = BNModel(16)
@ -77,10 +75,11 @@ def test_bn_handler():
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
bn_handler = BatchNormHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
bn_handler = BatchNormHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
bn_handler.register_strategy()
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']

11
tests/test_auto_parallel/test_conv_handler.py

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@ -31,7 +30,6 @@ def test_conv_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@ -77,10 +75,11 @@ def test_conv_handler():
# generate conv strategy
strategies_vector = StrategiesVector(node=nodes[2])
conv_handler = ConvHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
conv_handler = ConvHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
conv_handler.register_strategy()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

6
tests/test_auto_parallel/test_cost_graph.py

@ -4,10 +4,7 @@ from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
@ -37,7 +34,6 @@ def test_cost_graph():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@ -55,7 +51,7 @@ def test_cost_graph():
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
# (x, mul):{(0, 0): 0}

11
tests/test_auto_parallel/test_dot_handler.py

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@ -31,7 +30,6 @@ def test_dot_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = LinearModel(8, 16)
@ -76,10 +74,11 @@ def test_dot_handler():
# generate dot strategy
strategies_vector = StrategiesVector(node=nodes[2])
dot_handler = DotHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
dot_handler = DotHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
strategies_vector = dot_handler.register_strategy()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']

4
tests/test_auto_parallel/test_strategies_constructor.py

@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.options import SolverOptions
@ -34,7 +33,6 @@ def test_strategies_constructor():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@ -49,7 +47,7 @@ def test_strategies_constructor():
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
assert strategies_constructor.leaf_strategies == []
assert strategies_constructor.strategy_map == {}

Loading…
Cancel
Save