@@ -1,5 +1,6 @@
import operator
from copy import deepcopy
from typing import List
from typing import Dict, List, Union

import torch
from torch.fx import symbolic_trace
@@ -20,6 +21,35 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()


def size_processing(size: Union[int, torch.Size],
                    dim_partition_dict: Dict[int, List[int]],
                    device_mesh_info: Dict[int, int],
                    target_dim: int = None,
                    node_name: str = None):
    """
    This method will be invoked during runtime to convert the value produced by a size node back to
    the size of the original (global) tensor, based on the distributed information.
    """
    if target_dim is not None:
        assert isinstance(size, int)
        if target_dim in dim_partition_dict:
            total_shard_size = 1
            for shard_dim in dim_partition_dict[target_dim]:
                total_shard_size *= device_mesh_info[shard_dim]
            size = size * total_shard_size

    else:
        size = list(size)
        for dim, dim_size in enumerate(size):
            if dim in dim_partition_dict:
                total_shard_size = 1
                for shard_dim in dim_partition_dict[dim]:
                    total_shard_size *= device_mesh_info[shard_dim]
                size[dim] = dim_size * total_shard_size
        size = torch.Size(size)

    return size
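
# A worked example of size_processing with hypothetical values (illustrative only):
# a global tensor of shape (64, 32) sharded along dim 0 over mesh dims [0, 1] of a
# (2, 4) device mesh has a local shape of (8, 32), so at runtime
#   size_processing(torch.Size([8, 32]), {0: [0, 1]}, {0: 2, 1: 4}) -> torch.Size([64, 32])
#   size_processing(8, {0: [0, 1]}, {0: 2, 1: 4}, target_dim=0)     -> 64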


def _solution_annotatation(gm: torch.fx.GraphModule,
                           solution: List[int],
                           strategies_constructor: StrategiesConstructor = None):
@@ -103,6 +133,119 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    In the auto parallel system, tensors may be sharded across different devices, so the size of a
    sharded tensor needs to be converted back to the size of the original (global) tensor before it
    is consumed by user nodes such as torch.view, torch.reshape, etc. These nodes carry enough
    information, such as the input sharding_spec and output sharding_spec, to decide how to convert
    the size value.
    """
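    # A sketch of the rewrite performed below (illustrative only):
    #   before:  size = x.size()           # returns the local (sharded) size at runtime
    #   after:   size = x.size()
    #            size = size_processing(size, dim_partition_dict, device_mesh_info, target_dim, node_name)
    # and every other user of the original size node is redirected to the converted value.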
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    node_pairs = {}

    for node in nodes:

        if node.op == 'call_method' and node.target == 'size':
            # extract useful information from the size node:
            # dim_partition_dict tells us which dimensions of the size value
            # should be enlarged.
            sharding_spec = node.args[0].sharding_spec
            dim_partition_dict = sharding_spec.dim_partition_dict

            # there are two usages of torch.Tensor.size:
            #   tensor.size()
            #   tensor.size(dim)
            # if a target_dim is assigned, the output will be an int
            # instead of a torch.Size
            target_dim = None
            if len(node.args) > 1:
                target_dim = node.args[1]
                if target_dim < 0:
                    target_dim += node.args[0]._meta_data.dim()

            # the DeviceMesh information determines how to scale the size value
            device_mesh_info = {}
            for dim, dim_size in enumerate(device_mesh.mesh_shape):
                device_mesh_info[dim] = dim_size

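            # For example (hypothetical values): with device_mesh.mesh_shape == (2, 4) we get
            # device_mesh_info == {0: 2, 1: 4}; if dim_partition_dict == {0: [0, 1]}, then
            # dimension 0 of the local size will be scaled by 2 * 4 = 8 at runtime.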
            with mod_graph.inserting_after(node):
                size_processing_node = mod_graph.create_node('call_function',
                                                             size_processing,
                                                             args=(node, dim_partition_dict, device_mesh_info,
                                                                   target_dim, node.name))
            # store the original node and the processing node pair in the node_pairs dictionary;
            # it will be used to replace the original node with the processing node in slice objects
            node_pairs[node] = size_processing_node
            size_processing_node._meta_data = node._meta_data

            user_list = list(node.users.keys())
            for user in user_list:
                if user == size_processing_node:
                    continue
                new_args = list(user.args)
                new_kwargs = dict(user.kwargs)
                # the original node may be a positional argument or a keyword argument of the user node
                if node in new_args:
                    # substitute the original node with size_processing_node
                    new_args[new_args.index(node)] = size_processing_node
                    user.args = tuple(new_args)
                elif str(node) in new_kwargs:
                    # substitute the original node with size_processing_node
                    new_kwargs[str(node)] = size_processing_node
                    user.kwargs = new_kwargs

        if node.op == 'call_function' and node.target == operator.getitem:

            getitem_index = node.args[1]
            # slice objects are quite special in the torch.fx graph:
            # on the one hand, we treat a slice object like an int, so we
            # do not create a node for it; on the other hand, a slice object
            # may take an fx.Node as its argument, and that user relationship
            # cannot be tracked in the fx graph.
            # Therefore, we record the node pairs in this pass and use them
            # to replace the original node arguments inside a slice object if
            # they have been processed in the pass above.

            # There are three main usages of operator.getitem:
            #   getitem(input, int)
            #   getitem(input, slice)
            #   getitem(input, Tuple[slice])
            # In this pass, we need to process the last two cases, because
            # node arguments may potentially appear in them.
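            # For instance (hypothetical), tracing `x[:y.size(0)]` yields a getitem node whose
            # index is slice(None, size_node, None); since that size node has been replaced
            # above, the slice must be rebuilt here with the corresponding size_processing node.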
            if isinstance(getitem_index, slice):
                new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
                # check start, stop and step independently, since more than one of
                # them could be a size node recorded in node_pairs
                if getitem_index.start in node_pairs:
                    new_start = node_pairs[getitem_index.start]
                if getitem_index.stop in node_pairs:
                    new_stop = node_pairs[getitem_index.stop]
                if getitem_index.step in node_pairs:
                    new_step = node_pairs[getitem_index.step]
                new_slice_item = slice(new_start, new_stop, new_step)
                new_args = (node.args[0], new_slice_item)
                node.args = new_args

            elif isinstance(getitem_index, (tuple, list)):
                assert isinstance(getitem_index[0], slice)
                new_slice_items = []

                for slice_item in getitem_index:
                    new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
                    # check start, stop and step independently, since more than one of
                    # them could be a size node recorded in node_pairs
                    if slice_item.start in node_pairs:
                        new_start = node_pairs[slice_item.start]
                    if slice_item.stop in node_pairs:
                        new_stop = node_pairs[slice_item.stop]
                    if slice_item.step in node_pairs:
                        new_step = node_pairs[slice_item.step]
                    new_slice_item = slice(new_start, new_stop, new_step)
                    new_slice_items.append(new_slice_item)

                new_args = (node.args[0], tuple(new_slice_items))
                node.args = new_args

    return gm


def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    This pass will process node args to adapt to the distributed tensor layout.
@@ -138,6 +281,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
            method = getattr(node.args[0]._meta_data.__class__, node.target)
            # process the node with (input, *shape) style args
            if method in (torch.Tensor.view, torch.Tensor.reshape):

                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, (int, tuple, list)):
@@ -157,9 +301,17 @@
                    # 1. torch.view(input, *shape)
                    # 2. torch.view(input, shape)
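                    # For example (hypothetical values): a call traced as tensor.view(4, 8), with
                    # dim 0 of the output sharded across 2 devices, becomes tensor.view(2, 8) on
                    # each device; the same shrinking applies when the shape is passed as a tuple/list.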
                    if isinstance(new_args[1], int):
                        # we will skip the dim with -1 value
                        if new_args[dim + 1] == -1:
                            continue
                        else:
                            new_args[dim + 1] //= total_shard_size
                    else:
                        new_args[1] = list(new_args[1])
                        # we will skip the dim with -1 value
                        if new_args[1][dim] == -1:
                            continue
                        else:
                            new_args[1][dim] //= total_shard_size
                node.args = tuple(new_args)

@@ -298,6 +450,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                             strategies_constructor: StrategiesConstructor = None):
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
        gm, solution, strategies_constructor)
    gm = _size_value_converting(gm, device_mesh)
    gm = _node_args_converting(gm, device_mesh)
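    # Note on ordering: _size_value_converting runs before _node_args_converting; size() results are
    # first restored to global sizes, then view/reshape shape arguments are shrunk to the local layout.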
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
    # gm = implicit_comm_action_apply(gm)