[autoparallel] added generate_sharding_spec to utils (#1590)

pull/1591/head
Frank Lee 2022-09-13 15:43:22 +08:00 committed by GitHub
parent 49ccf8b5f8
commit 7c18a588c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 161 additions and 103 deletions

View File

@ -0,0 +1,33 @@
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), f'The given node has not attribte _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
elif isinstance(input_, torch.Tensor):
shape = input_.shape
else:
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
)
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec

View File

@ -4,6 +4,7 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec
__all__ = ['BatchNormHandler']
@ -114,13 +115,15 @@ class BatchNormHandler(OperatorHandler):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -153,7 +156,8 @@ class BatchNormHandler(OperatorHandler):
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
@ -188,13 +192,15 @@ class BatchNormHandler(OperatorHandler):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -228,13 +234,15 @@ class BatchNormHandler(OperatorHandler):
name = f'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -265,7 +273,8 @@ class BatchNormHandler(OperatorHandler):
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
@ -323,13 +332,15 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -363,13 +374,15 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -403,13 +416,15 @@ class BatchNormHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

View File

@ -4,6 +4,7 @@ import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec
__all__ = ['ConvHandler']
@ -108,13 +109,15 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -152,13 +155,15 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -192,13 +197,15 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -235,13 +242,15 @@ class ConvHandler(OperatorHandler):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -277,13 +286,15 @@ class ConvHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -320,13 +331,15 @@ class ConvHandler(OperatorHandler):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -362,13 +375,15 @@ class ConvHandler(OperatorHandler):
name = f'RR = RR x RR'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -402,13 +417,15 @@ class ConvHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@ -443,13 +460,15 @@ class ConvHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

View File

@ -3,6 +3,7 @@ import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from functools import reduce
from .._utils import generate_sharding_spec
__all__ = ['DotHandler']
@ -28,14 +29,16 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
# linear layer weight is transposed during init
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -66,15 +69,17 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -101,13 +106,15 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -134,13 +141,15 @@ class DotHandler(OperatorHandler):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict_for_input = {1: [mesh_dim]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -167,13 +176,15 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -200,13 +211,15 @@ class DotHandler(OperatorHandler):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -233,13 +246,15 @@ class DotHandler(OperatorHandler):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@ -266,13 +281,15 @@ class DotHandler(OperatorHandler):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
sharding_spec_for_ouput = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

View File

@ -60,16 +60,6 @@ class OperatorHandler(ABC):
"""
pass
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=tensor.shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
'''
Compute the memory cost per device with this specific strategy.

View File

@ -11,6 +11,7 @@ import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec
class StrategiesConstructor:
@ -36,23 +37,6 @@ class StrategiesConstructor:
self.shape_consistency_manager = shape_consistency_manager
self.solver_options = solver_options
def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict
where the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
if hasattr(node, '_meta_data'):
meta_tensor = node._meta_data
elif isinstance(node, torch.Tensor):
meta_tensor = node
else:
raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=meta_tensor.shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_resharding_costs(self, input_nodes, target_sharding_specs):
'''
Compute the resharding costs with this specific strategy.
@ -101,7 +85,7 @@ class StrategiesConstructor:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
@ -120,7 +104,7 @@ class StrategiesConstructor:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
@ -167,7 +151,7 @@ class StrategiesConstructor:
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
@ -223,7 +207,7 @@ class StrategiesConstructor:
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
@ -285,7 +269,7 @@ class StrategiesConstructor:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()