mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added generate_sharding_spec to utils (#1590)
parent 49ccf8b5f8
commit 7c18a588c8

_utils.py (new file):
@@ -0,0 +1,33 @@
+import torch
+from torch.fx.node import Node
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.device.device_mesh import DeviceMesh
+from typing import Union, Dict, List
+
+
+def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
+                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+    """
+    Generate the sharding spec of the tensor based on the given dim_partition_dict.
+
+
+    Args:
+        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a Node is given, the meta data associated with the node will be used.
+        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
+        dim_partition_dict (Dict[int, List[int]]): a dictionary specifying the sharding spec; each key is a tensor dimension and each value is the list of mesh dimensions it is sharded over.
+    """
+
+    if isinstance(input_, Node):
+        assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
+        meta_tensor = input_._meta_data
+        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
+        shape = meta_tensor.shape
+    elif isinstance(input_, torch.Tensor):
+        shape = input_.shape
+    else:
+        raise TypeError(
+            f'We cannot generate a sharding spec for {type(input_)}; only torch.fx.Node or torch.Tensor is expected.'
+        )
+
+    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
+    return sharding_spec
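
For orientation, a minimal, hypothetical usage sketch of the new helper. The module path and the two-argument DeviceMesh constructor follow how the auto-parallel tests appear to build meshes; the 2 x 2 mesh and the tensor shape are invented for illustration.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver._utils import generate_sharding_spec

# four devices arranged as a 2 x 2 logical mesh (assumed constructor signature)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# shard tensor dim 0 over mesh dim 0 and tensor dim 1 over mesh dim 1
tensor = torch.empty(16, 32)
dim_partition_dict = {0: [0], 1: [1]}
sharding_spec = generate_sharding_spec(tensor, device_mesh, dim_partition_dict)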

op_handler/batch_norm_handler.py:
@@ -4,6 +4,7 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from .._utils import generate_sharding_spec

 __all__ = ['BatchNormHandler']

@@ -114,13 +115,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -153,7 +156,8 @@ class BatchNormHandler(OperatorHandler):
             new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

             dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
-            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                                  dim_partition_dict_for_output)
             # the computation cost is all the same
             new_compute_cost = compute_cost

@@ -188,13 +192,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -228,13 +234,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'RR = RR x R'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -265,7 +273,8 @@ class BatchNormHandler(OperatorHandler):

        def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
            dim_partition_dict_for_output = {0: mesh_dim_list}
-            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                                  dim_partition_dict_for_output)

            # the computation cost is all the same
            new_compute_cost = compute_cost
@@ -323,13 +332,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -363,13 +374,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -403,13 +416,15 @@ class BatchNormHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
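
Throughout these handlers the strategy names use a positional shorthand: each position describes one tensor dimension, with R meaning replicated and S{i} meaning sharded along logical mesh dimension i. A sketch of how such a name translates into the dim_partition_dict arguments, assuming mesh_dim_0 = 0 and mesh_dim_1 = 1:

# How a strategy name such as 'S0S1 = S0R x RS1' maps onto the dicts.
# Replicated dimensions (R) are simply omitted from the dict; 'S01' means
# one tensor dimension sharded across both mesh dimensions at once.
dim_partition_dict_for_output = {0: [0], 1: [1]}    # output annotated S0S1
dim_partition_dict_for_input = {0: [0]}             # input annotated S0R
dim_partition_dict_for_weight = {1: [1]}            # weight annotated RS1
dim_partition_dict_double_shard = {0: [0, 1]}       # an S01 annotation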

op_handler/conv_handler.py:
@@ -4,6 +4,7 @@ import warnings
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from .._utils import generate_sharding_spec

 __all__ = ['ConvHandler']

@@ -108,13 +109,15 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -152,13 +155,15 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -192,13 +197,15 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -235,13 +242,15 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -277,13 +286,15 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -320,13 +331,15 @@ class ConvHandler(OperatorHandler):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -362,13 +375,15 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RR x RR'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -402,13 +417,15 @@ class ConvHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -443,13 +460,15 @@ class ConvHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

op_handler/dot_handler.py:
@@ -3,6 +3,7 @@ import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
 from functools import reduce
+from .._utils import generate_sharding_spec

 __all__ = ['DotHandler']

@@ -28,14 +29,16 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         # linear layer weight is transposed during init
         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
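
The "weight is transposed" comments above refer to torch.nn.Linear's storage layout: the weight has shape (out_features, in_features), so y = x @ W.T. Sharding the output-feature dimension of the matmul therefore partitions weight dim 0, while sharding the contracted input-feature dimension partitions weight dim 1. A quick check:

import torch

# nn.Linear stores weight as (out_features, in_features), so the S1 in
# 'S0S1 = S0R x RS1' (output features) lands on weight dim 0.
linear = torch.nn.Linear(in_features=32, out_features=64)
print(linear.weight.shape)    # torch.Size([64, 32])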
@@ -66,15 +69,17 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         # since weight of the linear layer is transposed
         # the actual dim to be sharded is 1
         dim_partition_dict_for_weight = {1: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -101,13 +106,15 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -134,13 +141,15 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

         dim_partition_dict_for_input = {1: [mesh_dim]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -167,13 +176,15 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -200,13 +211,15 @@ class DotHandler(OperatorHandler):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -233,13 +246,15 @@ class DotHandler(OperatorHandler):
         name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -266,13 +281,15 @@ class DotHandler(OperatorHandler):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)

         dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

op_handler/operator_handler.py:
@@ -60,16 +60,6 @@ class OperatorHandler(ABC):
         """
         pass

-    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
-        """
-        Generate the sharding spec of the tensor based on the given dim_partition_dict
-        where the key is the tensor dimension and the value is the mesh dimension for sharding.
-        """
-        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                     entire_shape=tensor.shape,
-                                     dim_partition_dict=dim_partition_dict)
-        return sharding_spec
-
     def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
         '''
         Compute the memory cost per device with this specific strategy.
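
With the private helper removed, spec construction no longer depends on handler state beyond the attributes passed in explicitly. A hypothetical minimal handler (names invented) showing the new call shape:

from colossalai.auto_parallel.solver._utils import generate_sharding_spec

class MyOpHandler:    # stand-in for an OperatorHandler subclass
    def __init__(self, input_data, device_mesh):
        self.input_data = input_data
        self.device_mesh = device_mesh

    def register_strategy(self, dim_partition_dict):
        # previously: self._generate_sharding_spec(self.input_data, dim_partition_dict)
        return generate_sharding_spec(self.input_data, self.device_mesh, dim_partition_dict)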

strategies_constructor.py:
@@ -11,6 +11,7 @@ import math
 import torch
 import operator
 from typing import Dict, List
+from ._utils import generate_sharding_spec


 class StrategiesConstructor:
@@ -36,23 +37,6 @@ class StrategiesConstructor:
         self.shape_consistency_manager = shape_consistency_manager
         self.solver_options = solver_options

-    def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
-        """
-        Generate the sharding spec of the tensor based on the given dim_partition_dict
-        where the key is the tensor dimension and the value is the mesh dimension for sharding.
-        """
-        if hasattr(node, '_meta_data'):
-            meta_tensor = node._meta_data
-        elif isinstance(node, torch.Tensor):
-            meta_tensor = node
-        else:
-            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
-
-        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                     entire_shape=meta_tensor.shape,
-                                     dim_partition_dict=dim_partition_dict)
-        return sharding_spec
-
     def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
         '''
         Compute the resharding costs with this specific strategy.
@@ -101,7 +85,7 @@ class StrategiesConstructor:
                # create sharding strategy for placeholder
                name = 'Replica Placeholder'
                dim_partition_dict = {}
-                output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
+                output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
                # TODO: use meta_info_prop to profile memory cost
                memory_cost = 0
                sharding_strategy_placeholder = ShardingStrategy(name,
@@ -120,7 +104,7 @@ class StrategiesConstructor:
                # create sharding strategy for get_attr
                name = 'Replica Attribute'
                dim_partition_dict = {}
-                output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
+                output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
                # TODO: use meta_info_prop to profile memory cost
                memory_cost = 0
                sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
@@ -167,7 +151,7 @@ class StrategiesConstructor:

                    sharding_spec_checklist.append(input_sharding_spec)
                    dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
-                    output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
+                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)

                    name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'

@@ -223,7 +207,7 @@ class StrategiesConstructor:

                    sharding_spec_checklist.append(input_sharding_spec)
                    dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
-                    output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
+                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)

                    name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'

@@ -285,7 +269,7 @@ class StrategiesConstructor:
                        continue
                    sharding_spec_checklist.append(input_sharding_spec)
                    dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
-                    output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
+                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
                    name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
                    # TODO: use meta_info_prop to profile memory cost and compute cost
                    compute_cost = node._meta_data.numel()
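
The shared helper also covers the Node path that StrategiesConstructor previously handled itself: for an fx node it reads the shape from the node's _meta_data attribute. A sketch of that path; in practice _meta_data is attached by ColossalAI's meta-information pass, and here it is set by hand purely for illustration.

import torch
from torch.fx import symbolic_trace

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2

gm = symbolic_trace(TinyModel())
node = next(iter(gm.graph.nodes))                     # the placeholder node for x
node._meta_data = torch.empty(4, 8, device='meta')    # normally set by the meta pass
# generate_sharding_spec(node, device_mesh, {}) would now read
# node._meta_data.shape and build a fully replicated spec.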