[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass

* add unit test

* polish
pull/2716/head
YuliangLiu0306 2023-02-15 10:36:19 +08:00 committed by GitHub
parent 89f8975fb8
commit cb2c6a2415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 352 additions and 214 deletions

View File

@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]

View File

@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager()
@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
return size
def _solution_annotatation(gm: torch.fx.GraphModule,
solution: List[int],
strategies_constructor: StrategiesConstructor = None):
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
"""
mod_graph = gm.graph
# TODO: In future PR, strategies_constructor should be a required argument,
# instead of optional argument. This is because we don't need to consider nodes with
# no strategy in runtime preparation pass.
if strategies_constructor is not None:
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
else:
nodes = tuple(mod_graph.nodes)
no_strategy_nodes = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
In the auto parallel system, tensors may get shard on different devices, so the size of tensors
need to be converted to the size of original tensor and managed by the users, such as torch.view,
@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
nodes = tuple(mod_graph.nodes)
node_pairs = {}
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
'''
A helper function to etract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
2. tensor.size(dim)
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
Otherwise, the output will be in type of torch.Size and this function will return None.
'''
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
return target_dim
def _post_processing(node, size_processing_node):
'''
This function is used to process the dependency between the size node and its users after
inserting the size_process_node.
'''
# store original node and processing node pair in node_pairs dictioanry
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if 'activation_checkpoint' in node.meta:
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
'''
This function is used to update the slice object argument list.
If the slice object contains the Node argument, then the size node will be replaced with
'''
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
step = slice_object.step
if start in node_pairs:
start = node_pairs[start]
if stop in node_pairs:
stop = node_pairs[stop]
if step in node_pairs:
step = node_pairs[step]
return slice(start, stop, step)
elif isinstance(slice_object, int):
if slice_object in node_pairs:
return node_pairs[slice_object]
else:
return slice_object
else:
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
sharding_spec = node.args[0].sharding_spec
dim_partition_dict = sharding_spec.dim_partition_dict
# there are two usages of torch.Tensor.size:
# tensor.size()
# tensor.size(dim)
# if a target_dim is assigned, then the output will be
# in type of int, instead of torch.Size
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
if target_dim < 0:
target_dim += node.args[0]._meta_data.dim()
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
device_mesh_info[dim] = dim_size
target_dim = _extract_target_dim(node)
# insert size_processing node
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
# store original node and processing node pair in node_pairs dictioanry
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if 'activation_checkpoint' in node.meta:
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
user_list = list(node.users.keys())
for user in user_list:
if user == size_processing_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if node in new_args:
# substitute the origin node with size_processing_node
new_args[new_args.index(node)] = size_processing_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with size_processing_node
new_kwargs[str(node)] = size_processing_node
user.kwargs = new_kwargs
_post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem:
@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# In this pass, we need process the last two cases because
# node arguments may potentially appear in these cases.
if isinstance(getitem_index, slice):
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
if getitem_index.start in node_pairs:
new_start = node_pairs[getitem_index.start]
elif getitem_index.stop in node_pairs:
new_stop = node_pairs[getitem_index.stop]
elif getitem_index.step in node_pairs:
new_step = node_pairs[getitem_index.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_item = _update_slice_object_args(getitem_index)
new_args = (node.args[0], new_slice_item)
node.args = new_args
@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if slice_item is None:
new_slice_items.append(None)
continue
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
if slice_item.start in node_pairs:
new_start = node_pairs[slice_item.start]
elif slice_item.stop in node_pairs:
new_stop = node_pairs[slice_item.stop]
elif slice_item.step in node_pairs:
new_step = node_pairs[slice_item.step]
new_slice_item = slice(new_start, new_stop, new_step)
new_slice_item = _update_slice_object_args(slice_item)
new_slice_items.append(new_slice_item)
new_args = (node.args[0], tuple(new_slice_items))
@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
return gm
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt the distributed tensor layout.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
'''
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
'''
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_extract_info_from_sharding_spec(element))
return dim_partition_dict, sharding_spec
def _process_node_arguments(node):
new_args = []
for arg in node.args:
# There are two args style:
# 1. (input, *shape)
# 2. (input, shape)
# We will extract the elements from shape and add them into the new_args
# Finally, the args style of new_args will be unified to (input, *shape)
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.extend(arg._meta_data)
elif isinstance(arg._meta_data, int):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg,
(int, tuple, list)), 'The argument in view node should be either type of Node or int.'
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
new_args.append(arg)
return new_args
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
if node.op == 'call_method':
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
for dim, shard_dims in dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
# we will skip the dim with -1 value
if args_to_process[dim] == -1:
continue
else:
# TODO: add assertion here to make sure the dim size is divisible by total_shard_size
args_to_process[dim] //= total_shard_size
args_to_process = tuple(args_to_process)
if node.op == 'call_method':
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
node.args = new_args
def _filter_node_with_shape_args(node):
if node.op == 'call_method':
target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function':
target = node.target
else:
target = None
if target in SHAPE_ARGUMENT_OPS:
return True
return False
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'):
continue
def _process_sharding_spec(sharding_spec):
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_process_sharding_spec(element))
return dim_partition_dict, sharding_spec
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
new_args = []
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
# process the node with (input, *shape) style args
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
# There are two ways to use torch.view:
# 1. torch.view(input, *shape)
# 2. torch.view(input, shape)
if isinstance(new_args[1], int):
# we will skip the dim with -1 value
if new_args[dim + 1] == -1:
continue
else:
new_args[dim + 1] //= total_shard_size
else:
new_args[1] = list(new_args[1])
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
else:
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
target = node.target
# process the node with (input, torch.Size) style args
if target in (torch.reshape,):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.append(list(arg._meta_data))
else:
new_args.append(arg)
else:
assert isinstance(
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
new_args.append(list(arg))
for dim, shard_dims in output_dim_partition_dict.items():
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
if _filter_node_with_shape_args(node):
_scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
return gm
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@ -361,6 +386,49 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
nodes = tuple(mod_graph.nodes)
# This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param):
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
return True
return False
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if _filter_param_to_hook(node, operation_data, comm_action):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
def _shard_param(param, target_sharding_spec):
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
@ -370,36 +438,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
_shard_param(param, target_sharding_spec)
setattr(target_module, name, param)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
_add_hook_for_grad_communication(node, param)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@ -427,37 +469,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParamter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
target = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
target_sharding_spec).detach().clone())
_shard_param(target, target_sharding_spec)
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)
_add_hook_for_grad_communication(node, target)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
return gm
@ -471,14 +488,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None,
strategies_constructor: StrategiesConstructor,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict

View File

@ -0,0 +1,54 @@
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
class TestModule(torch.nn.Module):
def forward(self, x):
x = x.view(4, 4, 2)
return x
def insert_narrow(gm, x_node):
graph = gm.graph
with graph.inserting_after(x_node):
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
view_node = list(x_node.users.keys())[0]
new_args = list(view_node.args)
new_args[0] = shard_node
view_node.args = tuple(new_args)
return gm
def test_node_args_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
x_node = list(graph.nodes)[0]
view_node = list(graph.nodes)[1]
sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', sharding_spec)
setattr(view_node, 'sharding_spec', sharding_spec)
gm = ColoGraphModule(model, graph)
gm = node_args_converting_pass(gm, device_mesh)
gm = insert_narrow(gm, x_node)
gm.recompile()
output = gm(input)
assert output.shape == torch.Size([2, 4, 2])
if __name__ == '__main__':
test_node_args_converting_pass()

View File

@ -0,0 +1,65 @@
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
class TestModule(torch.nn.Module):
def forward(self, x):
size = x.size()
return size
def insert_narrow(gm, x_node):
graph = gm.graph
with graph.inserting_after(x_node):
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
size_node = list(x_node.users.keys())[0]
size_node.args = (shard_node,)
return gm
def recover_narrow(gm, narrow_node):
graph = gm.graph
size_node = list(graph.nodes)[2]
x_node = narrow_node.args[0]
size_node.args = (x_node,)
graph.erase_node(narrow_node)
return gm
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
x_node = list(graph.nodes)[0]
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', x_sharding_spec)
gm = ColoGraphModule(model, graph)
gm = insert_narrow(gm, x_node)
gm.recompile()
size = gm(input)
assert size == torch.Size([2, 8])
narrow_node = list(gm.graph.nodes)[1]
gm = recover_narrow(gm, narrow_node)
gm = size_value_converting_pass(gm, device_mesh)
gm = insert_narrow(gm, x_node)
gm.recompile()
size = gm(input)
assert size == torch.Size([4, 8])
if __name__ == '__main__':
test_size_value_converting_pass()

View File

@ -1,12 +1,9 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (