mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] process size nodes in runtime pass (#2130)
* [autoparallel] process size nodes in runtime pass
* polish code
parent 536560ccc0
commit a3c6924deb
@@ -1,5 +1,6 @@
+import operator
 from copy import deepcopy
-from typing import List
+from typing import Dict, List, Union

 import torch
 from torch.fx import symbolic_trace
@@ -20,6 +21,35 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


+def size_processing(size: Union[int, torch.Size],
+                    dim_partition_dict: Dict[int, List[int]],
+                    device_mesh_info: Dict[int, int],
+                    target_dim: int = None,
+                    node_name: str = None):
+    """
+    This method will be invoked at runtime to convert the value of a size node according to the sharding information.
+    """
+    if target_dim is not None:
+        assert isinstance(size, int)
+        if target_dim in dim_partition_dict:
+            total_shard_size = 1
+            for shard_dim in dim_partition_dict[target_dim]:
+                total_shard_size *= device_mesh_info[shard_dim]
+            size = size * total_shard_size
+
+    else:
+        size = list(size)
+        for dim, dim_size in enumerate(size):
+            if dim in dim_partition_dict:
+                total_shard_size = 1
+                for shard_dim in dim_partition_dict[dim]:
+                    total_shard_size *= device_mesh_info[shard_dim]
+                size[dim] = dim_size * total_shard_size
+        size = torch.Size(size)
+
+    return size
+
+
 def _solution_annotatation(gm: torch.fx.GraphModule,
                            solution: List[int],
                            strategies_constructor: StrategiesConstructor = None):
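For reference, a minimal sketch of what the new size_processing helper computes. The mesh and shapes below are made up for illustration, and it assumes size_processing as defined in the hunk above is in scope:

    import torch

    # a 2 x 2 device mesh: mesh dim 0 and mesh dim 1 each hold 2 devices
    device_mesh_info = {0: 2, 1: 2}
    # tensor dim 0 is sharded over both mesh dims, so each device holds 1/4 of that dim
    dim_partition_dict = {0: [0, 1]}

    # tensor.size() case: the local torch.Size is scaled back to the global size
    assert size_processing(torch.Size([2, 16]), dim_partition_dict, device_mesh_info) == torch.Size([8, 16])

    # tensor.size(dim) case: an int is returned, selected via target_dim
    assert size_processing(2, dim_partition_dict, device_mesh_info, target_dim=0) == 8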
@@ -103,6 +133,119 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


+def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+    """
+    In the auto parallel system, tensors may be sharded across different devices, so the size observed at
+    runtime is the local one. It needs to be converted back to the size of the original (global) tensor before
+    it is consumed by users such as torch.view and torch.reshape. These nodes carry enough information, like the
+    input sharding_spec and the output sharding_spec, to decide how to convert the size value.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+    node_pairs = {}
+
+    for node in nodes:
+
+        if node.op == 'call_method' and node.target == 'size':
+            # extract useful information from the size node:
+            # dim_partition_dict instructs on which dimensions
+            # the size value should be enlarged.
+            sharding_spec = node.args[0].sharding_spec
+            dim_partition_dict = sharding_spec.dim_partition_dict
+
+            # there are two usages of torch.Tensor.size:
+            #     tensor.size()
+            #     tensor.size(dim)
+            # if a target_dim is assigned, then the output will be
+            # of type int instead of torch.Size
+            target_dim = None
+            if len(node.args) > 1:
+                target_dim = node.args[1]
+                if target_dim < 0:
+                    target_dim += node.args[0]._meta_data.dim()
+
+            # DeviceMesh information instructs the scaling of the size value
+            device_mesh_info = {}
+            for dim, dim_size in enumerate(device_mesh.mesh_shape):
+                device_mesh_info[dim] = dim_size
+
+            with mod_graph.inserting_after(node):
+                size_processing_node = mod_graph.create_node('call_function',
+                                                             size_processing,
+                                                             args=(node, dim_partition_dict, device_mesh_info,
+                                                                   target_dim, node.name))
+            # store the original node and the processing node as a pair in the node_pairs dictionary.
+            # It will be used to replace the original node with the processing node inside slice objects.
+            node_pairs[node] = size_processing_node
+            size_processing_node._meta_data = node._meta_data
+
+            user_list = list(node.users.keys())
+            for user in user_list:
+                if user == size_processing_node:
+                    continue
+                new_args = list(user.args)
+                new_kwargs = dict(user.kwargs)
+                # the origin node may be a positional argument or a keyword argument of the user node
+                if node in new_args:
+                    # substitute the origin node with size_processing_node
+                    new_args[new_args.index(node)] = size_processing_node
+                    user.args = tuple(new_args)
+                elif str(node) in new_kwargs:
+                    # substitute the origin node with size_processing_node
+                    new_kwargs[str(node)] = size_processing_node
+                    user.kwargs = new_kwargs
+
+        if node.op == 'call_function' and node.target == operator.getitem:
+
+            getitem_index = node.args[1]
+            # slice objects are quite special in a torch.fx graph:
+            # on one side, we treat a slice object the same as an int,
+            # so we do not create a node for it. On the other side,
+            # a slice object could take an fx.Node as its argument, and that
+            # user relationship cannot be tracked in the fx graph.
+            # Therefore, we record the node_pairs in this pass, and use them
+            # to replace the original node argument inside the slice object if
+            # it has been processed in the pass above.
+
+            # There are three main usages of operator.getitem:
+            #     getitem(input, int)
+            #     getitem(input, slice)
+            #     getitem(input, Tuple[slice])
+            # In this pass, we need to process the last two cases because
+            # node arguments may potentially appear in them.
+            if isinstance(getitem_index, slice):
+                new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
+                if getitem_index.start in node_pairs:
+                    new_start = node_pairs[getitem_index.start]
+                elif getitem_index.stop in node_pairs:
+                    new_stop = node_pairs[getitem_index.stop]
+                elif getitem_index.step in node_pairs:
+                    new_step = node_pairs[getitem_index.step]
+                new_slice_item = slice(new_start, new_stop, new_step)
+                new_args = (node.args[0], new_slice_item)
+                node.args = new_args
+
+            elif isinstance(getitem_index, (tuple, list)):
+                assert isinstance(getitem_index[0], slice)
+                new_slice_items = []
+
+                for slice_item in getitem_index:
+                    new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
+                    if slice_item.start in node_pairs:
+                        new_start = node_pairs[slice_item.start]
+                    elif slice_item.stop in node_pairs:
+                        new_stop = node_pairs[slice_item.stop]
+                    elif slice_item.step in node_pairs:
+                        new_step = node_pairs[slice_item.step]
+                    new_slice_item = slice(new_start, new_stop, new_step)
+                    new_slice_items.append(new_slice_item)
+
+                new_args = (node.args[0], tuple(new_slice_items))
+                node.args = new_args
+
+    return gm
+
+
 def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     This pass will process node args to adapt the distributed tensor layout.
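The core graph surgery in _size_value_converting is the standard torch.fx insert-and-rewire pattern. A self-contained toy version of the same mechanics is sketched below; the module, the helper name double_first_dim, and the doubling factor are purely illustrative and not part of the commit (the real pass inserts size_processing and relies on sharding specs instead):

    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):

        def forward(self, x):
            return x.view(x.size(0), -1)

    def double_first_dim(size_value):
        # stand-in for size_processing: pretend dim 0 was sharded over 2 devices
        return size_value * 2

    gm = symbolic_trace(Toy())
    for node in tuple(gm.graph.nodes):
        if node.op == 'call_method' and node.target == 'size':
            with gm.graph.inserting_after(node):
                new_node = gm.graph.create_node('call_function', double_first_dim, args=(node,))
            # redirect every user of the original size node except the node we just created
            for user in list(node.users):
                if user is new_node:
                    continue
                user.args = tuple(new_node if arg is node else arg for arg in user.args)
    gm.recompile()

    print(gm(torch.rand(2, 4)).shape)    # torch.Size([4, 2]): the view now sees the scaled size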
@@ -138,6 +281,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             method = getattr(node.args[0]._meta_data.__class__, node.target)
+            # process the node with (input, *shape) style args
             if method in (torch.Tensor.view, torch.Tensor.reshape):

                 for arg in node.args:
                     if isinstance(arg, Node):
                         if isinstance(arg._meta_data, (int, tuple, list)):
@@ -157,10 +301,18 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                             # 1. torch.view(input, *shape)
                             # 2. torch.view(input, shape)
                             if isinstance(new_args[1], int):
-                                new_args[dim + 1] //= total_shard_size
+                                # we will skip the dim with -1 value
+                                if new_args[dim + 1] == -1:
+                                    continue
+                                else:
+                                    new_args[dim + 1] //= total_shard_size
                             else:
                                 new_args[1] = list(new_args[1])
-                                new_args[1][dim] //= total_shard_size
+                                # we will skip the dim with -1 value
+                                if new_args[1][dim] == -1:
+                                    continue
+                                else:
+                                    new_args[1][dim] //= total_shard_size
                 node.args = tuple(new_args)

         elif node.op == 'call_function':
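A hypothetical worked check of the scaling rule handled in this hunk, assuming dim 0 of the input is sharded over 2 devices; it covers both tensor.view(8, -1) and tensor.view((8, -1)) argument styles, and -1 (inferred) dims are left untouched:

    total_shard_size = 2    # dim 0 is sharded over 2 devices
    sharded_dims = [0]

    # style 1: tensor.view(8, -1) -> per-device tensor.view(4, -1)
    new_args = ['input_node', 8, -1]
    for dim in sharded_dims:
        if new_args[dim + 1] != -1:
            new_args[dim + 1] //= total_shard_size
    assert new_args == ['input_node', 4, -1]

    # style 2: tensor.view((8, -1)) -> per-device tensor.view((4, -1))
    shape = [8, -1]
    for dim in sharded_dims:
        if shape[dim] != -1:
            shape[dim] //= total_shard_size
    assert shape == [4, -1]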
@@ -298,6 +450,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              strategies_constructor: StrategiesConstructor = None):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution, strategies_constructor)
+    gm = _size_value_converting(gm, device_mesh)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
     # gm = implicit_comm_action_apply(gm)
@@ -28,8 +28,9 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     # switch the dimensions of the transposed weight
     sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
     op_data = strategy.get_op_data_by_name(weight_name)
-    assert op_data.logical_shape != op_data.data.shape, \
-        "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
+    assert op_data.logical_shape[0] == op_data.data.shape[1] and \
+        op_data.logical_shape[1] == op_data.data.shape[0], \
+        "Expected the logical shape of the linear operator's weight to be equal to its transposed physical shape"
     dim_size = len(op_data.logical_shape)
     transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
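To illustrate the relationship the new assertion enforces (the layer sizes below are arbitrary): nn.Linear stores its weight physically as (out_features, in_features), while the logical shape referred to here is assumed to be the (in_features, out_features) view used for the matmul, i.e. the transpose:

    import torch

    linear = torch.nn.Linear(16, 32)
    physical_shape = linear.weight.shape    # torch.Size([32, 16])
    logical_shape = torch.Size([16, 32])    # assumed logical (matmul) view of the weight

    assert logical_shape[0] == physical_shape[1] and logical_shape[1] == physical_shape[0]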