
[autoparallel] fix linear logical convert issue (#1857)

Branch: pull/1880/head
Author: YuliangLiu0306 (2 years ago, committed via GitHub)
Commit: 1b494ad73c
Changed files:
  1. colossalai/auto_parallel/passes/runtime_preparation_pass.py (1 line changed)
  2. colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py (36 lines changed)
  3. colossalai/auto_parallel/tensor_shard/solver/solver.py (11 lines changed)
  4. tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (2 lines changed)

colossalai/auto_parallel/passes/runtime_preparation_pass.py (1 line changed)

@@ -52,7 +52,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         if node.op == 'get_attr':
             assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
             new_sharding_spec = target_sharding_specs[0]
-            user_node = node.strategies_vector.successor_nodes[0]
             user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
             op_data_in_user = user_strategy.get_op_data_by_name(str(node))
             origin_node_sharding_spec_dict[index] = new_sharding_spec

colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py (36 lines changed)

@@ -30,7 +30,8 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
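
Why dim_size - 1 instead of -1: the sharding spec records its partitions in dim_partition_dict keyed by dimension index, so (on one plausible reading of the bug) swapping dim 0 with the literal index -1 never matches the entry stored under the actual last dimension. Below is a minimal, self-contained sketch of that pitfall; swap_partition_dims is a hypothetical stand-in, not the repository's transpose_partition_dim.

    # Hypothetical stand-in for transpose_partition_dim, for illustration only.
    def swap_partition_dims(dim_partition_dict: dict, dim0: int, dim1: int) -> dict:
        """Swap the mesh-dim assignments recorded for two tensor dimensions."""
        swapped = dict(dim_partition_dict)
        swapped[dim0], swapped[dim1] = (dim_partition_dict.get(dim1, []),
                                        dim_partition_dict.get(dim0, []))
        # drop empty entries so untouched dims do not linger in the spec
        return {dim: mesh_dims for dim, mesh_dims in swapped.items() if mesh_dims}

    # 2-D logical weight [in, out] sharded on its last dim over mesh dim 0
    logical_partition = {1: [0]}
    print(swap_partition_dims(logical_partition, 0, -1))     # {1: [0]}: key -1 never matches key 1
    print(swap_partition_dims(logical_partition, 0, 2 - 1))  # {0: [0]}: dim_size - 1 hits the real last dim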
@@ -54,6 +55,29 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
     output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=input_sharding_spec,
+            dim_mapping={last_logical_input_dims: last_physical_input_dims},
+            physical_shape=input_op_data.data.shape,
+            inplace=True,
+        )
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=output_sharding_spec,
+            dim_mapping={last_logical_output_dims: last_physical_output_dims},
+            physical_shape=output_op_data.data.shape,
+            inplace=True,
+        )
     # get logger for debug message
     logger = get_dist_logger()
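
The added block re-attaches a partition recorded on the last logical dimension to the last physical dimension before the spec is applied to the real tensor. A standalone sketch of that index arithmetic follows; the names and the [B, S, H] shape are illustrative, not the handler's API.

    import torch

    # F.linear on a 3-D input is reasoned about through a 2-D logical view.
    physical_input = torch.randn(4, 16, 32)                                  # [B, S, H]
    logical_shape = physical_input.view(-1, physical_input.shape[-1]).shape  # [B*S, H] = [64, 32]

    last_logical_dim = len(logical_shape) - 1      # 1
    last_physical_dim = physical_input.dim() - 1   # 2

    # A logical partition {1: [0]} ("shard H over mesh dim 0") must land on
    # physical dim 2, which is what the dim_mapping above encodes.
    logical_partition = {last_logical_dim: [0]}
    physical_partition = {last_physical_dim: logical_partition[last_logical_dim]}
    print(physical_partition)  # {2: [0]}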
@@ -198,7 +222,14 @@ class LinearFunctionHandler(NodeHandler):
                                                 type=data_type,
                                                 data=self.node.args[1]._meta_data,
                                                 logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
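
The output now gets a 2-D logical shape, mirroring how the weight's logical shape is already the reversed physical shape. A quick standalone check of both shape computations with plain PyTorch (arbitrary sizes, only to show the shape math):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 16, 32)   # physical input  [B, S, in_features]
    w = torch.randn(64, 32)      # physical weight [out_features, in_features]
    out = F.linear(x, w)         # physical output [B, S, out_features] = [4, 16, 64]

    weight_logical_shape = w.shape[::-1]                      # (32, 64): [in, out], via shape[::-1]
    output_logical_shape = out.view(-1, out.shape[-1]).shape  # (64, 64): [B*S, out], via view(-1, shape[-1])
    print(weight_logical_shape, output_logical_shape)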
@@ -219,7 +250,6 @@ class LinearFunctionHandler(NodeHandler):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input

colossalai/auto_parallel/tensor_shard/solver/solver.py (11 lines changed)

@@ -32,7 +32,8 @@ class Solver:
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
-                 memory_increasing_coefficient: float = 1.3):
+                 memory_increasing_coefficient: float = 1.3,
+                 verbose=True):
         '''
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
         Argument:
@@ -64,6 +65,7 @@ class Solver:
         self.last_s_val = None
         # The last objective value of the best ILP solution.
         self.last_objective = None
+        self.verbose = verbose

     def _recover_merged_node_strategy(self):
         '''
@@ -177,7 +179,7 @@ class Solver:
         # omit initial value for nodes
         s_init_np = None
-        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
+        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose

     def _call_solver_serialized_args(self,
                                      node_nums,
@@ -192,7 +194,8 @@ class Solver:
                                      memory_costs,
                                      resharding_costs,
                                      alias_convert_costs,
-                                     s_init_np=None):
+                                     s_init_np=None,
+                                     verbose=True):
         """
         Call the solver with serialized arguments.
         """
@@ -407,8 +410,6 @@ class Solver:
         #             if v[idx][row * C + col] > 0.5:
         #                 prob += s[i][row] + s[j][col] <= 1
-        verbose = True
         msg = verbose
         time_limit = 600
         assert "COIN_CMD" in pulp.listSolvers(
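
With verbose now threaded from Solver.__init__ down to _call_solver_serialized_args, msg = verbose presumably controls whether the pulp backend prints its solver log. A minimal standalone sketch of that switch on a toy problem (using pulp's bundled CBC front end, not the repository's auto-parallel ILP):

    import pulp

    def solve_toy_ilp(verbose: bool = True) -> float:
        prob = pulp.LpProblem("toy", pulp.LpMinimize)
        x = pulp.LpVariable("x", lowBound=0)
        y = pulp.LpVariable("y", lowBound=0)
        prob += x + 2 * y   # objective
        prob += x + y >= 3  # constraint
        # msg=False silences the solver log, mirroring Solver(..., verbose=False)
        # in the test change below.
        prob.solve(pulp.PULP_CBC_CMD(msg=verbose))
        return pulp.value(prob.objective)

    print(solve_toy_ilp(verbose=False))  # 3.0, with no CBC output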

tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (2 lines changed)

@@ -95,7 +95,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
     graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
