mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support more flexible data type (#1967)
parent
5bec3b2168
commit
05020e50d0
|
@ -4,6 +4,7 @@ from .binary_elementwise_handler import BinaryElementwiseHandler
|
|||
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
|
||||
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
|
||||
from .getatrr_handler import GetattrHandler
|
||||
from .getitem_handler import GetItemHandler
|
||||
from .layer_norm_handler import LayerNormModuleHandler
|
||||
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from .matmul_handler import MatMulHandler
|
||||
|
@ -19,5 +20,6 @@ __all__ = [
|
|||
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
|
||||
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
|
||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
|
||||
'GetItemHandler', 'GetattrHandler'
|
||||
]
|
||||
|
|
|
@ -51,6 +51,10 @@ class NodeHandler(ABC):
|
|||
for node in self.predecessor_node:
|
||||
node_name = str(node)
|
||||
# get the current sharding spec generated by this node handler
|
||||
|
||||
# TODO: we need to check this in future
|
||||
if not isinstance(node._meta_data, torch.Tensor):
|
||||
continue
|
||||
op_data = strategy.get_op_data_by_name(node_name)
|
||||
current_sharding_spec = strategy.sharding_specs[op_data]
|
||||
|
||||
|
|
|
@ -11,7 +11,9 @@ __all__ = ['ReshapeHandler']
|
|||
|
||||
|
||||
@operator_registry.register(torch.reshape)
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.Tensor.transpose)
|
||||
@operator_registry.register(torch.Tensor.permute)
|
||||
@operator_registry.register(torch.Tensor.view)
|
||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||
|
@ -26,6 +28,24 @@ class ReshapeHandler(NodeHandler):
|
|||
generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def infer_logical_shape(self, data):
|
||||
"""
|
||||
This function is used to infer logical shape for operands.
|
||||
|
||||
Notes: This function is only used for the operands whose data are not only in type of tensor,
|
||||
such as tuple of tensor.
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.shape
|
||||
else:
|
||||
assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
|
||||
logical_shape = []
|
||||
for tensor in data:
|
||||
assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
|
||||
logical_shape.append(tensor.shape)
|
||||
logical_shape = tuple(logical_shape)
|
||||
return logical_shape
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
|
@ -36,10 +56,19 @@ class ReshapeHandler(NodeHandler):
|
|||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
input_logical_shape = self.infer_logical_shape(input_data)
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=data_type,
|
||||
data=self.node.args[0]._meta_data)
|
||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||
data=input_data,
|
||||
logical_shape=input_logical_shape)
|
||||
|
||||
output_data = self.node._meta_data
|
||||
output_logical_shape = self.infer_logical_shape(output_data)
|
||||
physical_output = OperationData(name=str(self.node),
|
||||
type=OperationDataType.OUTPUT,
|
||||
data=output_data,
|
||||
logical_shape=output_logical_shape)
|
||||
|
||||
mapping = {"input": physical_input_operand, "output": physical_output}
|
||||
|
||||
|
|
|
@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
|
|||
for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
|
||||
dim_size = len(logical_shape)
|
||||
dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
|
||||
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=logical_shape,
|
||||
dim_partition_dict=dim_partition_dict_element)
|
||||
sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=logical_shape,
|
||||
dim_partition_dict=dim_partition_dict_element)
|
||||
sharding_spec.append(sharding_spec_element)
|
||||
else:
|
||||
assert isinstance(
|
||||
op_data.data, torch.Tensor
|
||||
|
@ -193,18 +194,40 @@ class StrategyGenerator(ABC):
|
|||
Args:
|
||||
strategy (ShardingStrategy): the ShardingStrategy generated.
|
||||
key (str): the name of the operation data defined by the generator.
|
||||
|
||||
"""
|
||||
op_data = self.op_data[key]
|
||||
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
|
||||
|
||||
if len(sharded_shape) == 0:
|
||||
num_elements = 1
|
||||
def _compute_size_in_bytes_helper(sharding_spec, meta_data):
|
||||
sharded_shape = sharding_spec.get_sharded_shape_per_device()
|
||||
if len(sharded_shape) == 0:
|
||||
num_elements = 1
|
||||
else:
|
||||
num_elements = reduce(operator.mul, sharded_shape)
|
||||
dtype = getattr(meta_data, 'dtype')
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
return num_elements * size_per_elem_bytes
|
||||
|
||||
if isinstance(op_data.data, tuple):
|
||||
assert isinstance(strategy.sharding_specs[op_data], list), \
|
||||
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
|
||||
total_bytes = 0
|
||||
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
|
||||
meta_data = op_data.data[index]
|
||||
if isinstance(meta_data, torch.Tensor):
|
||||
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
|
||||
else:
|
||||
# if meta_data is not a tensor, we count the memroy as 0
|
||||
element_bytes = 0
|
||||
total_bytes += element_bytes
|
||||
|
||||
else:
|
||||
num_elements = reduce(operator.mul, sharded_shape)
|
||||
dtype = self.op_data[key].data.dtype
|
||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
return num_elements * size_per_elem_bytes
|
||||
if isinstance(op_data.data, torch.Tensor):
|
||||
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
|
||||
else:
|
||||
# if op_data.data is not a tensor, we count the memroy as 0
|
||||
total_bytes = 0
|
||||
|
||||
return total_bytes
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
"""
|
||||
|
|
|
@ -10,6 +10,8 @@ from .strategy import StrategyGenerator, UnaryElementwiseGenerator
|
|||
__all__ = ['UnaryElementwiseHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.to)
|
||||
@operator_registry.register(torch.Tensor.type)
|
||||
@operator_registry.register(torch.abs)
|
||||
@operator_registry.register(torch.nn.ReLU)
|
||||
class UnaryElementwiseHandler(NodeHandler):
|
||||
|
|
Loading…
Reference in New Issue