@@ -30,7 +30,8 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
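A minimal, self-contained sketch of the idea behind the hunk above (the helper below is made up for illustration, not the ColossalAI API): F.linear stores its weight as (out_features, in_features) while strategies are generated against the logical (in_features, out_features) view, so the mesh axes recorded for dim 0 and the last dim have to be swapped. Using dim_size - 1 instead of -1 matters because the partition dict is keyed by non-negative dimension indices.

def swap_partition_dims(dim_partition_dict: dict, dim0: int, dim1: int) -> dict:
    """Swap the mesh axes assigned to two tensor dimensions."""
    new_dict = dict(dim_partition_dict)
    axes0, axes1 = new_dict.pop(dim0, None), new_dict.pop(dim1, None)
    if axes1 is not None:
        new_dict[dim0] = axes1
    if axes0 is not None:
        new_dict[dim1] = axes0
    return new_dict

# a logical (in, out) weight sharded as {0: [0], 1: [1]} becomes {0: [1], 1: [0]}
# for the physical (out, in) layout; dim_size - 1 (= 1 here) is passed rather than -1
# so the resulting dict key stays a non-negative dimension index
print(swap_partition_dims({0: [0], 1: [1]}, 0, 1))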
@@ -54,6 +55,29 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
     output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
+
+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=input_sharding_spec,
+            dim_mapping={last_logical_input_dims: last_physical_input_dims},
+            physical_shape=input_op_data.data.shape,
+            inplace=True,
+        )
+
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=output_sharding_spec,
+            dim_mapping={last_logical_output_dims: last_physical_output_dims},
+            physical_shape=output_op_data.data.shape,
+            inplace=True,
+        )
+
     # get logger for debug message
     logger = get_dist_logger()
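A simplified sketch of what "recover the last logical dimension to physical dimension" means in the hunk above (remap_last_dim is an invented stand-in for update_partition_dim): strategies are generated on a 2D logical view with the leading dimensions flattened, so a partition keyed on the last logical dim must be re-keyed onto the last dim of the physical tensor.

def remap_last_dim(dim_partition_dict: dict, logical_ndim: int, physical_ndim: int) -> dict:
    """Move a partition keyed on the last logical dim onto the last physical dim."""
    last_logical, last_physical = logical_ndim - 1, physical_ndim - 1
    new_dict = dict(dim_partition_dict)
    if last_logical in new_dict:
        new_dict[last_physical] = new_dict.pop(last_logical)
    return new_dict

# an input with physical shape (B, S, H) is treated logically as (B*S, H),
# so a shard recorded on logical dim 1 ends up on physical dim 2
print(remap_last_dim({1: [0]}, logical_ndim=2, physical_ndim=3))  # {2: [0]}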
@@ -198,7 +222,14 @@ class LinearFunctionHandler(NodeHandler):
                                                type=data_type,
                                                data=self.node.args[1]._meta_data,
                                                logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
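A torch-only sketch of how the new output logical shape is derived in the hunk above: every leading dimension is flattened so the output, like the input, is treated as a 2D matrix while strategies are enumerated. The tensor below is just an example stand-in for the node's meta data.

import torch

output_meta_data = torch.empty(4, 16, 32)    # e.g. (batch, seq_len, out_features)
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
print(output_logical_shape)                  # torch.Size([64, 32])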
@@ -219,7 +250,6 @@ class LinearFunctionHandler(NodeHandler):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))

         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
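A hedged sketch of the idea described in the comment above (the helper name is invented, not the handler's actual code): because the logical input is 2D, a shard recorded on logical dim 0 can correspond to any of the leading physical dimensions, so one candidate partition is produced per choice.

def enumerate_dim0_mappings(dim_partition_dict: dict, physical_ndim: int) -> list:
    """Emit one candidate partition dict per leading physical dim that logical dim 0 could map to."""
    if 0 not in dim_partition_dict:
        return [dict(dim_partition_dict)]
    candidates = []
    for dim in range(physical_ndim - 1):    # every physical dim except the feature dim
        new_dict = {k: v for k, v in dim_partition_dict.items() if k != 0}
        new_dict[dim] = dim_partition_dict[0]
        candidates.append(new_dict)
    return candidates

# a (B, S, H) input whose logical dim 0 is sharded on mesh axis 0 yields two candidates:
# shard the batch dim, or shard the sequence dim
print(enumerate_dim0_mappings({0: [0]}, physical_ndim=3))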