[autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
YuliangLiu0306 2022-10-27 10:42:54 +08:00 committed by GitHub
parent 25952b67d7
commit b4cc59b61e
10 changed files with 283 additions and 60 deletions


@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
"""
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue
for user_node in node.strategies_vector.successor_nodes:
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
origin_index_args = user_node.args.index(node)
new_args = list(user_node.args)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or keyword argument of the user node
if node in new_args:
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
user_node.kwargs = new_kwargs
return gm
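
For reference, a minimal, self-contained torch.fx sketch (not part of this commit) of the substitution pattern used above: a new node is created before the user node, then the origin node is swapped out of the user's positional args or keyword args, whichever it appears in. TinyModel and the scale helper are invented stand-ins for the traced module and runtime_apply.

import torch
import torch.fx


def scale(x: torch.Tensor) -> torch.Tensor:
    # invented stand-in for runtime_apply: transforms the tensor flowing along this edge
    return x * 2


class TinyModel(torch.nn.Module):

    def forward(self, x):
        return torch.add(x, 1)


gm = torch.fx.symbolic_trace(TinyModel())
graph = gm.graph

for node in list(graph.nodes):
    if node.op != 'placeholder':
        continue
    for user in list(node.users):
        with graph.inserting_before(user):
            new_node = graph.create_node('call_function', scale, args=(node,))
        # the origin node may be a positional or a keyword argument of the user node
        if node in user.args:
            new_args = list(user.args)
            new_args[new_args.index(node)] = new_node
            user.args = tuple(new_args)
        elif node.name in user.kwargs:
            new_kwargs = dict(user.kwargs)
            new_kwargs[node.name] = new_node
            user.kwargs = new_kwargs

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))    # tensor([3., 3.]): (1 * 2) + 1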
@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]
if op_data.type == OperationDataType.PARAM:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
# the origin node may be a positional argument or keyword argument of the user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
new_kwargs = dict(node.kwargs)
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
node.kwargs = new_kwargs
else:
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or keyword argument of the user node
if node in new_args:
# substitute the origin node with comm_spec_apply_node
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm
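
A matching sketch (again not part of this commit) of the CommType.AFTER branch: the new node is inserted right after the producer, then every downstream user except the freshly inserted node is rewired, checking both args and kwargs. allreduce_stub and ReluAddModel are invented placeholders for runtime_comm_spec_apply and the traced module.

import torch
import torch.fx


def allreduce_stub(x: torch.Tensor) -> torch.Tensor:
    # invented placeholder for the real communication op
    return x


class ReluAddModel(torch.nn.Module):

    def forward(self, x):
        y = torch.relu(x)
        return torch.add(y, 1)


gm = torch.fx.symbolic_trace(ReluAddModel())
graph = gm.graph

for node in list(graph.nodes):
    if node.op != 'call_function' or node.target is not torch.relu:
        continue
    with graph.inserting_after(node):
        comm_node = graph.create_node('call_function', allreduce_stub, args=(node,))
    for user in list(node.users):
        if user is comm_node:
            continue    # do not rewire the node that was just inserted
        if node in user.args:
            new_args = list(user.args)
            new_args[new_args.index(node)] = comm_node
            user.args = tuple(new_args)
        elif node.name in user.kwargs:
            new_kwargs = dict(user.kwargs)
            new_kwargs[node.name] = comm_node
            user.kwargs = new_kwargs

graph.lint()
gm.recompile()
print(gm(torch.zeros(2)))    # tensor([1., 1.])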


@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())


@ -4,7 +4,6 @@ import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@ -223,8 +220,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
@ -277,8 +273,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@ -316,8 +311,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
@ -351,7 +345,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
@ -441,8 +436,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}


@ -109,7 +109,8 @@ class StrategyGenerator(ABC):
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
@ -117,7 +118,8 @@ class StrategyGenerator(ABC):
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""


@ -115,6 +115,7 @@ class CommAction:
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: any = None
@dataclass
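
A standalone sketch (invented names, not the ColossalAI classes) of how the passes above consume the new field: when key_for_kwarg is set, the tensor to communicate is looked up in the node's kwargs, otherwise it is taken by positional index via arg_index.

from dataclasses import dataclass
from typing import Any


@dataclass
class DummyCommAction:
    arg_index: int = -1
    key_for_kwarg: Any = None


@dataclass
class DummyNode:
    args: tuple
    kwargs: dict


def pick_comm_object(node: DummyNode, action: DummyCommAction):
    # mirrors the CommType.BEFORE branch of _comm_spec_apply shown earlier in this diff
    if action.key_for_kwarg is not None:
        return node.kwargs[action.key_for_kwarg]
    return node.args[action.arg_index]


node = DummyNode(args=('input_1',), kwargs={'bias': 'bias_param'})
assert pick_comm_object(node, DummyCommAction(arg_index=0)) == 'input_1'
assert pick_comm_object(node, DummyCommAction(key_for_kwarg='bias')) == 'bias_param'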


@ -1,5 +1,6 @@
from functools import reduce
import operator
from functools import reduce
import torch
import torch.distributed as dist
@ -11,7 +12,7 @@ class DeviceMesh:
can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
own latency and bandwidth. We use alpha-beta model to model the
communication cost.
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
@ -64,6 +65,18 @@ class DeviceMesh:
def logical_mesh_id(self):
return self._logical_mesh_id
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)
return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
@ -90,7 +103,7 @@ class DeviceMesh:
def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
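
The __deepcopy__ added above deep-copies every attribute except process_groups_dict, which is shared by reference because initialized process groups cannot be meaningfully duplicated. A self-contained sketch of that pattern with an invented MeshLike class:

import copy


class MeshLike:

    def __init__(self):
        self.mesh_shape = (2, 2)
        self.process_groups_dict = {'axis0': object()}    # stand-in for real process groups

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                setattr(result, k, v)    # share, don't copy
        return result


mesh = MeshLike()
mesh_copy = copy.deepcopy(mesh)
assert mesh_copy.process_groups_dict is mesh.process_groups_dict    # shared by reference
assert mesh_copy.mesh_shape == mesh.mesh_shape and mesh_copy is not mesh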


@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
pass
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
return global_tensor
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.


@ -6,7 +6,6 @@ from functools import reduce
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
@ -23,7 +22,7 @@ class _DimSpec:
This class is used internally in ShardingSpec.
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
'''
@ -62,7 +61,7 @@ class _DimSpec:
def build_difference_2d_dict(self):
'''
Build a difference mapping for 2D device mesh case. It will be used to
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
class ShardingSpec:
'''
Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
[R, R, S0, S1].
Argument:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of tensor before sharded.
@ -260,10 +259,10 @@ class ShardingSpec:
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to be compared with.
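
A hedged sketch of how such specs can be constructed and compared on CPU (no process groups needed); the import paths and the dim_partition_dict values are assumptions based on the usages visible elsewhere in this diff.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

entire_shape = torch.Size([4, 16, 64, 64])
# shard tensor dim 0 along mesh axis 0 and tensor dim 1 along mesh axis 1
sharded_spec = ShardingSpec(device_mesh, entire_shape, {0: [0], 1: [1]})
# an empty dim_partition_dict denotes the fully replicated layout, as used by to_global
global_spec = ShardingSpec(device_mesh, entire_shape, {})

print(sharded_spec.sharding_sequence)                          # expected: [S0, S1, R, R]
print(sharded_spec.sharding_sequence_difference(global_spec))  # conversion difference score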


@ -1,27 +1,44 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@parameterize('bias', [True, False])
def test_conv_module_handler(bias):
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
tracer = ColoTracer()
def check_conv_module_handler(rank, bias, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
input = torch.rand(4, 4, 64, 64).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in this graph
node_index = 1
# total number of conv strategies
strategy_number = 16
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input'])
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node)
@ -38,26 +55,26 @@ def test_conv_module_handler(bias):
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
# assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
# assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
if bias:
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
# assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([16])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([16])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
# assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
assert mapping['output'].type == OperationDataType.OUTPUT
@ -129,9 +146,28 @@ class ConvModel(nn.Module):
return x
@parameterize('bias', [True, False])
def test_conv_function_handler(bias):
model = ConvModel()
def check_conv_function_handler(rank, bias, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 64, 64).cuda()
others = torch.rand(16, 4, 3, 3).cuda()
input_args = [input, others]
meta_arg_names = ['input', 'others']
input_kwargs = {}
# total number of conv strategies
strategy_number = 16
node_index = 2
if bias:
bias_tensor = torch.rand(16).cuda()
input_kwargs['bias'] = bias_tensor
node_index += 1
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names,
input_kwargs)
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -143,10 +179,6 @@ def test_conv_function_handler(bias):
meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
if bias:
conv_mod_node = list(graph.nodes)[3]
@ -248,6 +280,26 @@ def test_conv_function_handler(bias):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_conv_module_handler()
test_conv_function_handler()


@ -0,0 +1,126 @@
import copy
from typing import Dict, List
import torch
from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close
def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]):
model_to_compare = copy.deepcopy(model)
args_to_compare = []
kwargs_to_compare = {}
for arg_index, input_tensor in enumerate(input_args):
def wrapper(param, index):
def hook_fn(grad):
grad_dict[index] = grad
param.register_hook(hook_fn)
arg_to_compare = copy.deepcopy(input_tensor)
arg_to_compare.requires_grad = True
wrapper(arg_to_compare, arg_index)
# arg_to_compare.register_hook(hook_fn)
args_to_compare.append(arg_to_compare)
for name, input_kwarg in input_kwargs.items():
def wrapper(param, name):
def hook_fn(grad):
grad_dict[name] = grad
param.register_hook(hook_fn)
kwarg_to_compare = copy.deepcopy(input_kwarg)
kwarg_to_compare.requires_grad = True
wrapper(kwarg_to_compare, name)
kwargs_to_compare[name] = kwarg_to_compare
return model_to_compare, args_to_compare, kwargs_to_compare
def numerical_test_for_node_strategy(model: torch.nn.Module,
device_mesh: DeviceMesh,
node_index: int,
strategy_number: int,
input_args: List[torch.Tensor],
meta_arg_names: List[str],
input_kwargs: Dict[str, torch.Tensor] = {}):
for strategy_index in range(strategy_number):
print(f'#strategy_index: {strategy_index}')
# We need to copy the model to avoid running backward more than once on the same graph
grad_to_compare_dict = {}
grad_to_shard_dict = {}
model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare(
model, input_args, input_kwargs, grad_to_compare_dict)
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict)
zero_tensor = torch.Tensor(0).cuda()
tracer = ColoTracer()
input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
target_node = list(graph.nodes)[node_index]
# solution construction
solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len
solution[node_index] = strategy_index
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()
# forward result compare
output = gm(*args_to_shard,
sharding_spec_convert_dict=sharding_spec_dict,
origin_node_sharding_spec_dict=origin_spec_dict,
comm_actions_dict=comm_actions_dict,
**kwargs_to_shard)
# except:
# print(gm)
output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare)
assert_close((output - output_to_compare).sum(), zero_tensor)
# backward result compare
loss = output.sum()
loss_to_compare = output_to_compare.sum()
loss.backward()
loss_to_compare.backward()
for key in grad_to_shard_dict.keys():
grad_to_shard = grad_to_shard_dict[key]
grad_to_compare = grad_to_compare_dict[key]
assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor)
# extract the strategy used in this iter
strategy_in_use = target_node.strategies_vector[strategy_index]
param_to_shard_dict = dict(model_to_shard.named_parameters())
param_to_compare_dict = dict(model_to_compare.named_parameters())
for name in param_to_shard_dict.keys():
param_name = name.split('.')[-1]
param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
grad_sharded = param_to_shard_dict[name].grad
grad_to_compare = param_to_compare_dict[name].grad
global_grad = to_global(grad_sharded, param_sharding_spec)
assert_close((global_grad - grad_to_compare).sum(), zero_tensor)