[autoparallel] support more flexible data type (#1967)

YuliangLiu0306 2 years ago committed by GitHub
parent 5bec3b2168
commit 05020e50d0

@@ -4,6 +4,7 @@ from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .getatrr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
@@ -19,5 +20,6 @@ __all__ = [
    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
    'GetItemHandler', 'GetattrHandler'
]
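
For context, the decorator-based registration that these imports feed into works roughly as sketched below; the Registry class and its get() lookup are illustrative stand-ins for the real Colossal-AI registry, not its actual implementation.

import operator


class Registry:
    """Illustrative stand-in for the node-handler operator registry."""

    def __init__(self):
        self._store = {}

    def register(self, target):
        # used as a decorator: maps a torch/Python callable to a handler class
        def wrapper(handler_cls):
            self._store[target] = handler_cls
            return handler_cls

        return wrapper

    def get(self, target):
        return self._store[target]


operator_registry = Registry()


@operator_registry.register(operator.getitem)
class GetItemHandler:
    ...


assert operator_registry.get(operator.getitem) is GetItemHandler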

@@ -51,6 +51,10 @@ class NodeHandler(ABC):
        for node in self.predecessor_node:
            node_name = str(node)
            # get the current sharding spec generated by this node handler
            # TODO: we need to check this in the future
            if not isinstance(node._meta_data, torch.Tensor):
                continue
            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]
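
To see why the new isinstance guard matters: a traced node for an op such as torch.Tensor.split carries a tuple of tensors as its _meta_data rather than a single tensor, so it is skipped when resharding costs are accumulated. A minimal illustration in plain PyTorch (no tracing involved):

import torch

# split returns a tuple of tensors, so the corresponding node's meta data
# is a tuple rather than a torch.Tensor
x = torch.empty(4, 8)
chunks = x.split(2, dim=0)
print(type(chunks).__name__, isinstance(chunks, torch.Tensor))  # tuple False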

@@ -11,7 +11,9 @@ __all__ = ['ReshapeHandler']
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.view)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
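
The newly registered targets only rearrange tensor metadata rather than compute new values, which is why they can reuse the reshape-style strategy generation; a quick sanity check:

import torch

# transpose / permute / view return tensors with the same elements and a
# different shape, so reshape-style sharding strategies carry over
x = torch.arange(6).reshape(2, 3)
print(x.transpose(0, 1).shape)  # torch.Size([3, 2])
print(x.permute(1, 0).shape)    # torch.Size([3, 2])
print(x.view(3, 2).shape)       # torch.Size([3, 2])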
@@ -26,6 +28,24 @@ class ReshapeHandler(NodeHandler):
        generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def infer_logical_shape(self, data):
        """
        Infer the logical shape of an operand.
        Note: this is only needed for operands whose data is not a plain tensor,
        e.g. a tuple of tensors.
        """
        if isinstance(data, torch.Tensor):
            return data.shape
        else:
            assert isinstance(data, tuple), "input_data should be a tuple of tensors or a tensor."
            logical_shape = []
            for tensor in data:
                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensors or a tensor."
                logical_shape.append(tensor.shape)
            logical_shape = tuple(logical_shape)
            return logical_shape

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use the transposed shape for strategies
        # the strategies will be transformed back to the original shape in self.post_process
@@ -36,10 +56,19 @@
        else:
            data_type = OperationDataType.ARG
        input_data = self.node.args[0]._meta_data
        input_logical_shape = self.infer_logical_shape(input_data)
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=data_type,
                                               data=self.node.args[0]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
                                               data=input_data,
                                               logical_shape=input_logical_shape)
        output_data = self.node._meta_data
        output_logical_shape = self.infer_logical_shape(output_data)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_data,
                                        logical_shape=output_logical_shape)
        mapping = {"input": physical_input_operand, "output": physical_output}
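
As a standalone illustration of the logical-shape handling added above (the function body mirrors infer_logical_shape; the example tensors are made up):

import torch


def infer_logical_shape(data):
    # a tensor maps to its own shape; a tuple of tensors maps to a tuple of shapes
    if isinstance(data, torch.Tensor):
        return data.shape
    assert isinstance(data, tuple), "input_data should be a tuple of tensors or a tensor."
    return tuple(t.shape for t in data)


single = torch.empty(4, 8)
split_out = single.split(2, dim=0)      # tuple of two (2, 8) tensors

print(infer_logical_shape(single))      # torch.Size([4, 8])
print(infer_logical_shape(split_out))   # (torch.Size([2, 8]), torch.Size([2, 8]))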

@@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
                    for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                        dim_size = len(logical_shape)
                        dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                                     entire_shape=logical_shape,
                                                     dim_partition_dict=dim_partition_dict_element)
                        sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
                                                             entire_shape=logical_shape,
                                                             dim_partition_dict=dim_partition_dict_element)
                        sharding_spec.append(sharding_spec_element)
                else:
                    assert isinstance(
                        op_data.data, torch.Tensor
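
The rename above fixes a shadowing bug: the loop reassigned the name of the accumulator list, so instead of collecting one spec per tensor the list was overwritten on every iteration. Sketched below with dicts standing in for ShardingSpec objects:

shapes = [(4, 8), (2, 8)]

# buggy pattern (old code): the accumulator is clobbered inside the loop
sharding_spec = []
for shape in shapes:
    sharding_spec = {"entire_shape": shape}   # the list built so far is lost
print(type(sharding_spec).__name__)           # dict, not the intended list of specs

# fixed pattern (new code): build each element under a distinct name and append it
sharding_spec = []
for shape in shapes:
    sharding_spec_element = {"entire_shape": shape}
    sharding_spec.append(sharding_spec_element)
print(len(sharding_spec))                     # 2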
@@ -193,18 +194,40 @@
        Args:
            strategy (ShardingStrategy): the ShardingStrategy generated.
            key (str): the name of the operation data defined by the generator.
        """
        op_data = self.op_data[key]
        sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
        if len(sharded_shape) == 0:
            num_elements = 1

        def _compute_size_in_bytes_helper(sharding_spec, meta_data):
            sharded_shape = sharding_spec.get_sharded_shape_per_device()
            if len(sharded_shape) == 0:
                num_elements = 1
            else:
                num_elements = reduce(operator.mul, sharded_shape)
            dtype = getattr(meta_data, 'dtype')
            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
            return num_elements * size_per_elem_bytes

        if isinstance(op_data.data, tuple):
            assert isinstance(strategy.sharding_specs[op_data], list), \
                'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
            total_bytes = 0
            for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
                meta_data = op_data.data[index]
                if isinstance(meta_data, torch.Tensor):
                    element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
                else:
                    # if meta_data is not a tensor, we count the memory as 0
                    element_bytes = 0
                total_bytes += element_bytes
        else:
            num_elements = reduce(operator.mul, sharded_shape)
        dtype = self.op_data[key].data.dtype
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        return num_elements * size_per_elem_bytes
            if isinstance(op_data.data, torch.Tensor):
                total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
            else:
                # if op_data.data is not a tensor, we count the memory as 0
                total_bytes = 0
        return total_bytes

    def generate(self) -> List[ShardingStrategy]:
        """

@@ -10,6 +10,8 @@ from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.Tensor.to)
@operator_registry.register(torch.Tensor.type)
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
class UnaryElementwiseHandler(NodeHandler):
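
The newly registered targets are shape-preserving elementwise casts or maps, so the existing unary-elementwise strategies transfer directly; for instance:

import torch

# .to, .type and torch.abs keep the input shape; only the dtype or the values change
x = torch.randn(4, 8)
print(x.to(torch.float16).shape)   # torch.Size([4, 8])
print(x.type(torch.int32).shape)   # torch.Size([4, 8])
print(torch.abs(x).shape)          # torch.Size([4, 8])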
