mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix linear logical convert issue (#1857)
parent c2947dadf1
commit 1b494ad73c
@@ -52,7 +52,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     if node.op == 'get_attr':
         assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
         new_sharding_spec = target_sharding_specs[0]
-        user_node = node.strategies_vector.successor_nodes[0]
         user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
         op_data_in_user = user_strategy.get_op_data_by_name(str(node))
         origin_node_sharding_spec_dict[index] = new_sharding_spec
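For orientation in the hunks that follow: each operand in a strategy carries a ColossalAI ShardingSpec whose dim_partition_dict maps a tensor dimension to the device-mesh axes it is sharded over. Below is a minimal sketch of that convention, using a plain dict and an invented mesh shape rather than the real class:

```python
# Sketch of the dim_partition_dict convention: {tensor_dim: [mesh_axes]}.
# A plain dict stands in for ColossalAI's ShardingSpec; the (2, 4) mesh
# shape is made up for illustration.

weight_shape = (1024, 512)             # physical shape of a Linear weight: [out, in]
dim_partition_dict = {0: [0], 1: [1]}  # dim 0 sharded over mesh axis 0, dim 1 over axis 1

def shard_size(shape, partition, mesh_shape):
    """Per-device shard size implied by a partition dict."""
    sizes = list(shape)
    for dim, axes in partition.items():
        for axis in axes:
            sizes[dim] //= mesh_shape[axis]
    return tuple(sizes)

print(shard_size(weight_shape, dim_partition_dict, mesh_shape=(2, 4)))
# (512, 128): dim 0 halved by mesh axis 0, dim 1 quartered by mesh axis 1
```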
@@ -30,7 +30,8 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
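The fix here swaps the partition entries of the weight's first and last dimensions using an explicit dim_size - 1 instead of -1. F.linear stores its weight physically as [out_features, in_features], while strategies are generated against the logical [in_features, out_features] view, so those two entries must be exchanged. A plausible reading of why -1 was a bug (an assumption, the diff doesn't say) is that dim_partition_dict is keyed by non-negative dimension indices, so a literal -1 key never matches the last dimension. A hypothetical stand-in for transpose_partition_dim makes the difference visible:

```python
# Hypothetical stand-in for transpose_partition_dim, operating on a bare dict.
def transpose_partition_dim_sketch(partition, dim0, dim1):
    """Swap the mesh axes recorded for dim0 and dim1."""
    d0, d1 = partition.pop(dim0, None), partition.pop(dim1, None)
    if d1 is not None:
        partition[dim0] = d1
    if d0 is not None:
        partition[dim1] = d0
    return partition

dim_size = 2                                                      # the weight is 2-D
print(transpose_partition_dim_sketch({0: [0]}, 0, dim_size - 1))  # {1: [0]}
print(transpose_partition_dim_sketch({0: [0]}, 0, -1))            # {-1: [0]}: never matches dim 1
```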
@@ -54,6 +55,29 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+    output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
+
+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=input_sharding_spec,
+            dim_mapping={last_logical_input_dims: last_physical_input_dims},
+            physical_shape=input_op_data.data.shape,
+            inplace=True,
+        )
+
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=output_sharding_spec,
+            dim_mapping={last_logical_output_dims: last_physical_output_dims},
+            physical_shape=output_op_data.data.shape,
+            inplace=True,
+        )
+
 # get logger for debug message
 logger = get_dist_logger()
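The added block moves any shard recorded on the last logical dimension onto the last physical dimension of the real tensor. The need for this only shows up when the input is more than 2-D; with illustrative shapes (not taken from the patch):

```python
# Why the last-dim remap is needed: the solver sees Linear operands as 2-D,
# but the physical tensor may be N-D. Shapes below are illustrative.
import torch

x = torch.empty(8, 16, 512)                    # physical input: [B, S, H]
logical_shape = x.view(-1, x.shape[-1]).shape  # torch.Size([128, 512])

last_logical_dim = len(logical_shape) - 1      # 1
last_physical_dim = x.dim() - 1                # 2

# A partition keyed on logical dim 1 must be re-keyed to physical dim 2.
dim_partition_dict = {1: [0]}                  # H sharded over mesh axis 0
if last_logical_dim in dim_partition_dict:
    dim_partition_dict[last_physical_dim] = dim_partition_dict.pop(last_logical_dim)
print(dim_partition_dict)                      # {2: [0]}
```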
@@ -198,7 +222,14 @@ class LinearFunctionHandler(NodeHandler):
                                                type=data_type,
                                                data=self.node.args[1]._meta_data,
                                                logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )
 
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
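The new output operand mirrors the convention the handler already uses for the weight a few lines up: the weight's logical shape is its physical shape reversed, and the output's logical shape flattens every leading dimension into one row dimension. With assumed example shapes:

```python
# Deriving the logical shapes used in this mapping (shapes are assumed).
import torch

weight = torch.empty(1024, 512)   # F.linear weight: [out_features, in_features]
out = torch.empty(8, 16, 1024)    # node output

weight_logical = weight.shape[::-1]                 # torch.Size([512, 1024]): [in, out]
output_logical = out.view(-1, out.shape[-1]).shape  # torch.Size([128, 1024])
print(weight_logical, output_logical)
```

Keeping every operand 2-D logically lets one matmul-style strategy space cover linear calls of any input rank.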
@@ -219,7 +250,6 @@ class LinearFunctionHandler(NodeHandler):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
-
         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
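The closing comment describes the fan-out step: a shard recorded on logical dim 0 may land on any leading physical dimension of the input, producing one candidate strategy each. A hypothetical enumeration of the idea (the handler's actual loop sits outside this diff):

```python
# Hypothetical fan-out of a logical-dim-0 shard over an N-D input's leading dims.
def enumerate_dim0_placements(logical_partition, physical_ndim):
    """Yield one physical partition dict per leading-dim placement of logical dim 0."""
    mesh_axes = logical_partition.get(0)
    if mesh_axes is None:                   # nothing sharded on logical dim 0
        yield dict(logical_partition)
        return
    for dim in range(physical_ndim - 1):    # every physical dim except the last
        physical = {k: v for k, v in logical_partition.items() if k != 0}
        physical[dim] = mesh_axes
        yield physical

print(list(enumerate_dim0_placements({0: [0]}, physical_ndim=3)))
# [{0: [0]}, {1: [0]}] -- shard the batch dim or the sequence dim
```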
@@ -32,7 +32,8 @@ class Solver:
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
-                 memory_increasing_coefficient: float = 1.3):
+                 memory_increasing_coefficient: float = 1.3,
+                 verbose=True):
         '''
         Solver class will integrate information provided by the components and use an ILP solver to find a possible optimal strategy combination for the target computation graph.
         Argument:
@@ -64,6 +65,7 @@ class Solver:
         self.last_s_val = None
         # The last objective value of the best ILP solution.
         self.last_objective = None
+        self.verbose = verbose
 
     def _recover_merged_node_strategy(self):
         '''
@@ -177,7 +179,7 @@ class Solver:
         # omit initial value for nodes
         s_init_np = None
 
-        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
+        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
 
     def _call_solver_serialized_args(self,
                                      node_nums,
@@ -192,7 +194,8 @@ class Solver:
                                      memory_costs,
                                      resharding_costs,
                                      alias_convert_costs,
-                                     s_init_np=None):
+                                     s_init_np=None,
+                                     verbose=True):
         """
         Call the solver with serialized arguments.
         """
@@ -407,8 +410,6 @@
         # if v[idx][row * C + col] > 0.5:
         #     prob += s[i][row] + s[j][col] <= 1
 
-        verbose = True
-
         msg = verbose
         time_limit = 600
         assert "COIN_CMD" in pulp.listSolvers(
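With the hardcoded verbose = True gone, verbosity flows from Solver(..., verbose=...) through _call_solver_serialized_args down to pulp. A minimal sketch of that plumbing; pulp.listSolvers and pulp.COIN_CMD are the library's real API, while the variable names mirror the hunk (actually solving still requires the cbc binary to be installed):

```python
# Sketch: thread the constructor's verbose flag into pulp's CBC backend.
import pulp

verbose = False                   # e.g. Solver(..., verbose=False) in tests
msg = verbose
time_limit = 600

assert "COIN_CMD" in pulp.listSolvers()
solver = pulp.COIN_CMD(msg=msg, timeLimit=time_limit)  # silent when msg=False
```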
@@ -95,7 +95,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
     graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(