From 7c18a588c849358f032098f4418f10e5b8fd1335 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 13 Sep 2022 15:43:22 +0800
Subject: [PATCH] [autoparallel] added generate_sharding_spec to utils (#1590)

---
 colossalai/auto_parallel/solver/_utils.py    | 33 +++++++++
 .../solver/op_handler/batch_norm_handler.py  | 55 +++++++-----
 .../solver/op_handler/conv_handler.py        | 73 ++++++++++++-------
 .../solver/op_handler/dot_handler.py         | 65 +++++++++++------
 .../solver/op_handler/operator_handler.py    | 10 ---
 .../solver/strategies_constructor.py         | 28 ++-----
 6 files changed, 161 insertions(+), 103 deletions(-)
 create mode 100644 colossalai/auto_parallel/solver/_utils.py

diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py
new file mode 100644
index 000000000..54c9269a4
--- /dev/null
+++ b/colossalai/auto_parallel/solver/_utils.py
@@ -0,0 +1,33 @@
+import torch
+from torch.fx.node import Node
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.device.device_mesh import DeviceMesh
+from typing import Union, Dict, List
+
+
+def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
+                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+    """
+    Generate the sharding spec of the tensor based on the given dim_partition_dict.
+
+
+    Args:
+        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is given, the shape is taken from the meta data associated with this node.
+        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
+        dim_partition_dict (Dict[int, List[int]]): a dictionary specifying the sharding spec; the key is the tensor dimension and the value is the list of mesh dimensions to shard along.
+    """
+
+    if isinstance(input_, Node):
+        assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
+        meta_tensor = input_._meta_data
+        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
+        shape = meta_tensor.shape
+    elif isinstance(input_, torch.Tensor):
+        shape = input_.shape
+    else:
+        raise TypeError(
+            f'We cannot generate a sharding spec for {type(input_)}; only torch.fx.Node or torch.Tensor is expected.'
+ ) + + sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) + return sharding_spec diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py index eac2f62cc..8e6b1a7c0 100644 --- a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py @@ -4,6 +4,7 @@ import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler +from .._utils import generate_sharding_spec __all__ = ['BatchNormHandler'] @@ -114,13 +115,15 @@ class BatchNormHandler(OperatorHandler): name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -153,7 +156,8 @@ class BatchNormHandler(OperatorHandler): new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]} - new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # the computation cost is all the same new_compute_cost = compute_cost @@ -188,13 +192,15 @@ class BatchNormHandler(OperatorHandler): name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -228,13 +234,15 @@ class BatchNormHandler(OperatorHandler): name = f'RR = RR x R' dim_partition_dict_for_input = {} - 
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -265,7 +273,8 @@ class BatchNormHandler(OperatorHandler): def _construct_batch_sharding_strategies(mesh_dim_list, new_name): dim_partition_dict_for_output = {0: mesh_dim_list} - new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # the computation cost is all the same new_compute_cost = compute_cost @@ -323,13 +332,15 @@ class BatchNormHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -363,13 +374,15 @@ class BatchNormHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -403,13 +416,15 @@ class BatchNormHandler(OperatorHandler): name = 
f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler.py b/colossalai/auto_parallel/solver/op_handler/conv_handler.py index e3f8a6a21..6c1b92d4a 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/conv_handler.py @@ -4,6 +4,7 @@ import warnings import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler +from .._utils import generate_sharding_spec __all__ = ['ConvHandler'] @@ -108,13 +109,15 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -152,13 +155,15 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = 
generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -192,13 +197,15 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -235,13 +242,15 @@ class ConvHandler(OperatorHandler): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -277,13 +286,15 @@ class ConvHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -320,13 +331,15 @@ class 
ConvHandler(OperatorHandler): name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -362,13 +375,15 @@ class ConvHandler(OperatorHandler): name = f'RR = RR x RR' dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -402,13 +417,15 @@ class ConvHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) @@ -443,13 +460,15 @@ class ConvHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) 
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/solver/op_handler/dot_handler.py index 26791df46..9fa99f748 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler.py @@ -3,6 +3,7 @@ import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from functools import reduce +from .._utils import generate_sharding_spec __all__ = ['DotHandler'] @@ -28,14 +29,16 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' dim_partition_dict_for_input = {0: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) # linear layer weight is transposed during init dim_partition_dict_for_weight = {0: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -66,15 +69,17 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) # since weight of the linear layer is transposed # the actual dim to be sharded is 1 dim_partition_dict_for_weight = {1: [mesh_dim_0]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -101,13 
+106,15 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' dim_partition_dict_for_input = {1: [mesh_dim_0]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_input) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -134,13 +141,15 @@ class DotHandler(OperatorHandler): name = f'RR = RS{mesh_dim} x S{mesh_dim}R' dim_partition_dict_for_input = {1: [mesh_dim]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -167,13 +176,15 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim} = RR x RS{mesh_dim}' dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -200,13 +211,15 @@ class DotHandler(OperatorHandler): name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {} - 
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -233,13 +246,15 @@ class DotHandler(OperatorHandler): name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) @@ -266,13 +281,15 @@ class DotHandler(OperatorHandler): name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' dim_partition_dict_for_input = {} - sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input) + sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh, + dim_partition_dict_for_input) dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) + sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight) dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) + sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh, + dim_partition_dict_for_output) # generate resharding cost for this strategy resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/solver/op_handler/operator_handler.py index 3c0e98cf4..dc397514e 100644 --- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/operator_handler.py @@ -60,16 +60,6 @@ class OperatorHandler(ABC): """ pass - def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - """ - Generate the sharding spec of the tensor based on the given dim_partition_dict - where the key is the tensor dimension and the value is the mesh dimension for sharding. 
- """ - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=tensor.shape, - dim_partition_dict=dim_partition_dict) - return sharding_spec - def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight): ''' Compute the memory cost per device with this specific strategy. diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index 08867c591..6343e201c 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -11,6 +11,7 @@ import math import torch import operator from typing import Dict, List +from ._utils import generate_sharding_spec class StrategiesConstructor: @@ -36,23 +37,6 @@ class StrategiesConstructor: self.shape_consistency_manager = shape_consistency_manager self.solver_options = solver_options - def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: - """ - Generate the sharding spec of the tensor based on the given dim_partition_dict - where the key is the tensor dimension and the value is the mesh dimension for sharding. - """ - if hasattr(node, '_meta_data'): - meta_tensor = node._meta_data - elif isinstance(node, torch.Tensor): - meta_tensor = node - else: - raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.') - - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=meta_tensor.shape, - dim_partition_dict=dim_partition_dict) - return sharding_spec - def _generate_resharding_costs(self, input_nodes, target_sharding_specs): ''' Compute the resharding costs with this specific strategy. @@ -101,7 +85,7 @@ class StrategiesConstructor: # create sharding strategy for placeholder name = 'Replica Placeholder' dim_partition_dict = {} - output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) # TODO: use meta_info_prop to profile memory cost memory_cost = 0 sharding_strategy_placeholder = ShardingStrategy(name, @@ -120,7 +104,7 @@ class StrategiesConstructor: # create sharding strategy for get_attr name = 'Replica Attribute' dim_partition_dict = {} - output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) # TODO: use meta_info_prop to profile memory cost memory_cost = 0 sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) @@ -167,7 +151,7 @@ class StrategiesConstructor: sharding_spec_checklist.append(input_sharding_spec) dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' @@ -223,7 +207,7 @@ class StrategiesConstructor: sharding_spec_checklist.append(input_sharding_spec) dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' @@ -285,7 +269,7 @@ class 
StrategiesConstructor: continue sharding_spec_checklist.append(input_sharding_spec) dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' # TODO: use meta_info_prop to profile memory cost and compute cost compute_cost = node._meta_data.numel()