[autoparallel] process size nodes in runtime pass (#2130)

* [autoparallel] process size nodes in runtime pass

* polish code
YuliangLiu0306 2 years ago committed by GitHub
parent 536560ccc0
commit a3c6924deb

@@ -1,5 +1,6 @@
import operator
from copy import deepcopy
from typing import Dict, List, Union
import torch
from torch.fx import symbolic_trace
@@ -20,6 +21,35 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size],
dim_partition_dict: Dict[int, List[int]],
device_mesh_info: Dict[int, int],
target_dim: int = None,
node_name: str = None):
"""
This function is invoked at runtime to convert the value produced by a size node according to the sharding information, i.e. to recover the global size from the local shard size.
"""
if target_dim is not None:
assert isinstance(size, int)
if target_dim in dim_partition_dict:
total_shard_size = 1
for shard_dim in dim_partition_dict[target_dim]:
total_shard_size *= device_mesh_info[shard_dim]
size = size * total_shard_size
else:
size = list(size)
for dim, dim_size in enumerate(size):
if dim in dim_partition_dict:
total_shard_size = 1
for shard_dim in dim_partition_dict[dim]:
total_shard_size *= device_mesh_info[shard_dim]
size[dim] = dim_size * total_shard_size
size = torch.Size(size)
return size
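For intuition, here is a minimal usage sketch of size_processing with made-up sharding values (a 2x2 device mesh and a tensor of global shape (8, 16) sharded along dim 0 over mesh dim 0, so the local shard has shape (4, 16)):

import torch

device_mesh_info = {0: 2, 1: 2}    # hypothetical 2x2 mesh: {mesh_dim: size}
dim_partition_dict = {0: [0]}      # tensor dim 0 is sharded over mesh dim 0

# tensor.size(dim): the local int value 4 is scaled back to the global value 8
assert size_processing(4, dim_partition_dict, device_mesh_info, target_dim=0) == 8

# tensor.size(): every sharded dim of the local torch.Size is scaled back
assert size_processing(torch.Size([4, 16]), dim_partition_dict, device_mesh_info) == torch.Size([8, 16])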
def _solution_annotatation(gm: torch.fx.GraphModule,
solution: List[int],
strategies_constructor: StrategiesConstructor = None):
@@ -103,6 +133,119 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
In the auto parallel system, tensors may be sharded across different devices, so the value
produced by a size node describes the local shard rather than the original tensor. The user
nodes of that value, such as torch.view and torch.reshape, expect the global size, and the
input and output sharding specs carry enough information to decide how to convert it.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
node_pairs = {}
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
# extract useful information from the size node;
# dim_partition_dict tells us which dimensions of the
# size value need to be scaled back up
sharding_spec = node.args[0].sharding_spec
dim_partition_dict = sharding_spec.dim_partition_dict
# there are two usages of torch.Tensor.size:
# tensor.size()
# tensor.size(dim)
# if a target_dim is given, the output will be
# an int instead of a torch.Size
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
# the DeviceMesh information determines how the size value is scaled
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
# store the original node and its processing node as a pair in the node_pairs dictionary;
# it will be used to replace the original node inside slice objects later
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional or keyword argument of the user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
if node.op == 'call_function' and node.target == operator.getitem:
getitem_index = node.args[1]
# slice objects are quite special in a torch.fx graph:
# on one hand, we treat a slice object like an int, so we
# do not create a node for it. On the other hand, a slice
# object may take an fx.Node as its argument, and that user
# relationship is not tracked in the fx graph.
# Therefore, we record node_pairs in this pass and use it
# to replace the original node argument inside the slice
# object if the node was processed above.
# There are three main usages of operator.getitem:
# getitem(input, int)
# getitem(input, slice)
# getitem(input, Tuple[slice])
# In this pass, we only need to process the last two cases, because
# node arguments may appear in them.
if isinstance(getitem_index, slice):
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
if getitem_index.start in node_pairs:
new_start = node_pairs[getitem_index.start]
elif getitem_index.stop in node_pairs:
new_stop = node_pairs[getitem_index.stop]
elif getitem_index.step in node_pairs:
new_step = node_pairs[getitem_index.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_args = (node.args[0], new_slice_item)
node.args = new_args
elif isinstance(getitem_index, (tuple, list)):
assert isinstance(getitem_index[0], slice)
new_slice_items = []
for slice_item in getitem_index:
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
if slice_item.start in node_pairs:
new_start = node_pairs[slice_item.start]
elif slice_item.stop in node_pairs:
new_stop = node_pairs[slice_item.stop]
elif slice_item.step in node_pairs:
new_step = node_pairs[slice_item.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_items.append(new_slice_item)
new_args = (node.args[0], tuple(new_slice_items))
node.args = new_args
return gm
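The rewiring pattern above (insert a processing node right after each size node and point the other users at it) can be reproduced with plain torch.fx. The sketch below is only an illustration: double_size and ToyModel are made up, and the real pass inserts size_processing with the node's sharding information instead.

import torch
from torch.fx import symbolic_trace


def double_size(size):
    # stand-in for size_processing: pretend every dim needs to be doubled
    return torch.Size([s * 2 for s in size])


class ToyModel(torch.nn.Module):

    def forward(self, x):
        return x.view(x.size())


gm = symbolic_trace(ToyModel())
for node in tuple(gm.graph.nodes):
    if node.op == 'call_method' and node.target == 'size':
        with gm.graph.inserting_after(node):
            new_node = gm.graph.create_node('call_function', double_size, args=(node,))
        # redirect every user of the size node except the freshly inserted node
        node.replace_all_uses_with(new_node, delete_user_cb=lambda user: user is not new_node)
gm.recompile()
print(gm.graph)    # size -> double_size -> view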
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt the distributed tensor layout.
@@ -138,6 +281,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
method = getattr(node.args[0]._meta_data.__class__, node.target)
# process the node with (input, *shape) style args
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (int, tuple, list)):
@@ -157,10 +301,18 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# 1. torch.view(input, *shape)
# 2. torch.view(input, shape)
if isinstance(new_args[1], int):
# we will skip the dim with -1 value
if new_args[dim + 1] == -1:
continue
else:
new_args[dim + 1] //= total_shard_size
else:
new_args[1] = list(new_args[1])
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
else:
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
@@ -298,6 +450,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
strategies_constructor: StrategiesConstructor = None):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented once the implementation of implicit_comm_action_apply_pass is completed.
# gm = implicit_comm_action_apply(gm)

@@ -28,8 +28,9 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape[0] == op_data.data.shape[1] and \
op_data.logical_shape[1] == op_data.data.shape[0], \
"Expected the logical shape of the linear operator's weight to equal its transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
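The tightened assertion encodes the usual nn.Linear layout: the weight is stored physically as (out_features, in_features), while the strategy reasons about the logical, transposed shape. A standalone, hypothetical check (not using the actual OperationData object):

import torch

linear = torch.nn.Linear(in_features=4, out_features=8)
physical_shape = linear.weight.shape     # torch.Size([8, 4])
logical_shape = torch.Size([4, 8])       # transposed shape the strategy reasons about

assert logical_shape[0] == physical_shape[1] and logical_shape[1] == physical_shape[0]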
