diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index df2d30cbc..614fb66f4 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -52,7 +52,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         if node.op == 'get_attr':
             assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
             new_sharding_spec = target_sharding_specs[0]
-            user_node = node.strategies_vector.successor_nodes[0]
             user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
             op_data_in_user = user_strategy.get_op_data_by_name(str(node))
             origin_node_sharding_spec_dict[index] = new_sharding_spec
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index d1ea84b39..5aa769981 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -30,7 +30,8 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
 
 
@@ -54,6 +55,29 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+    output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
+
+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=input_sharding_spec,
+            dim_mapping={last_logical_input_dims: last_physical_input_dims},
+            physical_shape=input_op_data.data.shape,
+            inplace=True,
+        )
+
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=output_sharding_spec,
+            dim_mapping={last_logical_output_dims: last_physical_output_dims},
+            physical_shape=output_op_data.data.shape,
+            inplace=True,
+        )
 
     # get logger for debug message
     logger = get_dist_logger()
@@ -198,7 +222,14 @@ class LinearFunctionHandler(NodeHandler):
                                                type=data_type,
                                                data=self.node.args[1]._meta_data,
                                                logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )
 
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
 
@@ -219,7 +250,6 @@ class LinearFunctionHandler(NodeHandler):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
-
         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensinal and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index d6ce5e9fe..7f972884e 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -32,7 +32,8 @@ class Solver:
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
-                 memory_increasing_coefficient: float = 1.3):
+                 memory_increasing_coefficient: float = 1.3,
+                 verbose=True):
         '''
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
         Argument:
@@ -64,6 +65,7 @@ class Solver:
         self.last_s_val = None
         # The last objective value of the best ILP solution.
         self.last_objective = None
+        self.verbose = verbose
 
     def _recover_merged_node_strategy(self):
         '''
@@ -177,7 +179,7 @@ class Solver:
         # omit initial value for nodes
         s_init_np = None
 
-        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
+        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
 
     def _call_solver_serialized_args(self,
                                      node_nums,
@@ -192,7 +194,8 @@ class Solver:
                                      memory_costs,
                                      resharding_costs,
                                      alias_convert_costs,
-                                     s_init_np=None):
+                                     s_init_np=None,
+                                     verbose=True):
         """
         Call the solver with serialized arguments.
         """
@@ -407,8 +410,6 @@ class Solver:
         #             if v[idx][row * C + col] > 0.5:
        #                 prob += s[i][row] + s[j][col] <= 1
 
-        verbose = True
-
         msg = verbose
         time_limit = 600
         assert "COIN_CMD" in pulp.listSolvers(
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index d871db144..b39a7b0cc 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -95,7 +95,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
     graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(