[autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
pull/1766/head
YuliangLiu0306 2022-10-27 10:42:54 +08:00 committed by GitHub
parent 25952b67d7
commit b4cc59b61e
10 changed files with 283 additions and 60 deletions

View File

@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
"""
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
if not hasattr(node, 'best_strategy') or node.op == 'output':
continue
for user_node in node.strategies_vector.successor_nodes:
user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
origin_index_args = user_node.args.index(node)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or a keyword argument of the user node
if node in new_args:
# substitute the origin node with shape_consistency_node
origin_index_args = new_args.index(node)
new_args[origin_index_args] = shape_consistency_node
user_node.args = new_args
elif str(node) in new_kwargs:
# substitute the origin node with shape_consistency_node
new_kwargs[str(node)] = shape_consistency_node
user_node.kwargs = new_kwargs
return gm
@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
comm_object = node.args[comm_action.arg_index]
if op_data.type == OperationDataType.PARAM:
continue
if comm_action.comm_type == CommType.BEFORE:
if comm_action.key_for_kwarg is not None:
comm_object = node.kwargs[comm_action.key_for_kwarg]
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
# the comm object may be passed to the node as a positional argument or a keyword argument
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
new_kwargs = dict(node.kwargs)
new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
node.kwargs = new_kwargs
else:
# substitute the origin node with comm_spec_apply_node
new_args = list(node.args)
new_args[comm_action.arg_index] = comm_spec_apply_node
node.args = new_args
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
if user == comm_spec_apply_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or a keyword argument of the user node
if node in new_args:
# substitute the origin node with comm_spec_apply_node
new_args[new_args.index(node)] = comm_spec_apply_node
user.args = tuple(new_args)
elif str(node) in new_kwargs:
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
return gm
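
Both rewiring branches above follow the standard torch.fx editing pattern: create a new call_function node, then replace the original node wherever a consumer references it, whether positionally or by keyword. Below is a minimal, self-contained sketch of that pattern; the toy scale function and Toy module are placeholders for illustration only, not the runtime_apply/runtime_comm_spec_apply machinery itself.

import torch
import torch.fx as fx

def scale(x, factor=2.0):
    # toy stand-in for runtime_apply / runtime_comm_spec_apply
    return x * factor

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return torch.add(y, x)

gm = fx.symbolic_trace(Toy())
graph = gm.graph
relu_node = next(n for n in graph.nodes if n.target is torch.relu)

# snapshot the users first: inserting new nodes adds users to relu_node
for user in list(relu_node.users):
    with graph.inserting_before(user):
        new_node = graph.create_node('call_function', scale, args=(relu_node,))
    new_args = list(user.args)
    new_kwargs = dict(user.kwargs)
    # the original node may appear as a positional or a keyword argument of the user
    if relu_node in new_args:
        new_args[new_args.index(relu_node)] = new_node
        user.args = tuple(new_args)
    elif relu_node in new_kwargs.values():
        user.kwargs = {k: (new_node if v is relu_node else v) for k, v in new_kwargs.items()}

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))    # relu(x) is doubled before the add: tensor([3., 3.])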

View File

@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
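
For reference, the surrounding pass replaces each flagged parameter with a sharded copy at runtime; detach().clone() makes the new Parameter own its storage. A toy, single-process sketch of that re-wrapping step, with a hypothetical shard_dim0 slice standing in for shape_consistency_manager.apply_for_autoparallel_runtime:

import torch
import torch.nn as nn

def shard_dim0(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # hypothetical stand-in for the shape-consistency conversion: keep only this rank's slice
    return tensor.chunk(world_size, dim=0)[rank]

linear = nn.Linear(8, 8)
rank, world_size = 0, 2
# detach().clone() so the sharded Parameter owns independent storage
linear.weight = nn.Parameter(shard_dim0(linear.weight.data, rank, world_size).detach().clone())
print(linear.weight.shape)    # torch.Size([4, 8])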

View File

@ -4,7 +4,6 @@ import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@ -223,8 +220,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
@ -277,8 +273,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@ -316,8 +311,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}
@ -351,7 +345,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE)
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
@ -441,8 +436,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER,
arg_index=0)
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_action}

View File

@ -109,7 +109,8 @@ class StrategyGenerator(ABC):
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1) -> CommAction:
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
@ -117,7 +118,8 @@ class StrategyGenerator(ABC):
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index)
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""

View File

@ -115,6 +115,7 @@ class CommAction:
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: any = None
@dataclass

View File

@ -1,5 +1,6 @@
from functools import reduce
import operator
from functools import reduce
import torch
import torch.distributed as dist
@ -64,6 +65,18 @@ class DeviceMesh:
def logical_mesh_id(self):
return self._logical_mesh_id
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
setattr(result, k, v)
return result
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
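
The k != 'process_groups_dict' branch above exists because process groups are live communicator handles that cannot be meaningfully duplicated, so a deep copy of the mesh must share them. A standalone illustration of the same __deepcopy__ pattern, using a hypothetical MeshLike class rather than DeviceMesh itself:

import copy

class MeshLike:
    def __init__(self, mesh, process_groups_dict):
        self.mesh = mesh                                  # plain data: safe to deep-copy
        self.process_groups_dict = process_groups_dict    # live handles: must be shared

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                setattr(result, k, v)    # share the original handle
        return result

original = MeshLike([[0, 1], [2, 3]], {'axis0': object()})
clone = copy.deepcopy(original)
assert clone.mesh == original.mesh and clone.mesh is not original.mesh
assert clone.process_groups_dict is original.process_groups_dict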

View File

@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
pass
def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
shape_consistency_manager = ShapeConsistencyManager()
global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
with torch.no_grad():
global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
global_sharding_spec)
return global_tensor
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.
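
The new to_global helper converts a tensor that lives under a given sharding spec back to its full, unsharded shape; the test utility at the end of this diff uses it to compare sharded gradients against a single-device reference. A toy, single-process analogy of that semantics, with no ShapeConsistencyManager or live DeviceMesh involved:

import torch

full = torch.arange(8.0)
shards = list(full.chunk(4, dim=0))       # what each of 4 devices would hold under a dim-0 shard
reassembled = torch.cat(shards, dim=0)    # conceptually what to_global recovers
assert torch.equal(reassembled, full)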

View File

@ -6,7 +6,6 @@ from functools import reduce
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

View File

@ -1,27 +1,44 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@parameterize('bias', [True, False])
def test_conv_module_handler(bias):
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
tracer = ColoTracer()
def check_conv_module_handler(rank, bias, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
input = torch.rand(4, 4, 64, 64).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in this graph
node_index = 1
# total number of conv strategies
strategy_number = 16
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input'])
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node)
@ -38,26 +55,26 @@ def test_conv_module_handler(bias):
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
# assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
# assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
if bias:
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
# assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([16])
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['bias'].logical_shape == torch.Size([16])
assert mapping['output'].name == "_0"
assert mapping['output'].data.is_meta
# assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
assert mapping['output'].type == OperationDataType.OUTPUT
@ -129,9 +146,28 @@ class ConvModel(nn.Module):
return x
@parameterize('bias', [True, False])
def test_conv_function_handler(bias):
model = ConvModel()
def check_conv_function_handler(rank, bias, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 64, 64).cuda()
others = torch.rand(16, 4, 3, 3).cuda()
input_args = [input, others]
meta_arg_names = ['input', 'others']
input_kwargs = {}
# total number of conv strategies
strategy_number = 16
node_index = 2
if bias:
bias_tensor = torch.rand(16).cuda()
input_kwargs['bias'] = bias_tensor
node_index += 1
numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names,
input_kwargs)
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@ -143,10 +179,6 @@ def test_conv_function_handler(bias):
meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
if bias:
conv_mod_node = list(graph.nodes)[3]
@ -248,6 +280,26 @@ def test_conv_function_handler(bias):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_conv_module_handler()
test_conv_function_handler()

View File

@ -0,0 +1,126 @@
import copy
from typing import Dict, List
import torch
from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close
def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]):
model_to_compare = copy.deepcopy(model)
args_to_compare = []
kwargs_to_compare = {}
for arg_index, input_tensor in enumerate(input_args):
def wrapper(param, index):
def hook_fn(grad):
grad_dict[index] = grad
param.register_hook(hook_fn)
arg_to_compare = copy.deepcopy(input_tensor)
arg_to_compare.requires_grad = True
wrapper(arg_to_compare, arg_index)
# arg_to_compare.register_hook(hook_fn)
args_to_compare.append(arg_to_compare)
for name, input_kwarg in input_kwargs.items():
def wrapper(param, name):
def hook_fn(grad):
grad_dict[name] = grad
param.register_hook(hook_fn)
kwarg_to_compare = copy.deepcopy(input_kwarg)
kwarg_to_compare.requires_grad = True
wrapper(kwarg_to_compare, name)
kwargs_to_compare[name] = kwarg_to_compare
return model_to_compare, args_to_compare, kwargs_to_compare
def numerical_test_for_node_strategy(model: torch.nn.Module,
device_mesh: DeviceMesh,
node_index: int,
strategy_number: int,
input_args: List[torch.Tensor],
meta_arg_names: List[str],
input_kwargs: Dict[str, torch.Tensor] = {}):
for strategy_index in range(strategy_number):
print(f'#strategy_index: {strategy_index}')
# We need to copy the model to avoid running backward more than once on the same graph
grad_to_compare_dict = {}
grad_to_shard_dict = {}
model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare(
model, input_args, input_kwargs, grad_to_compare_dict)
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict)
zero_tensor = torch.Tensor(0).cuda()
tracer = ColoTracer()
input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
target_node = list(graph.nodes)[node_index]
# solution construction
solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len
solution[node_index] = strategy_index
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh)
gm = runtime_apply_pass(gm)
gm.recompile()
# forward result compare
output = gm(*args_to_shard,
sharding_spec_convert_dict=sharding_spec_dict,
origin_node_sharding_spec_dict=origin_spec_dict,
comm_actions_dict=comm_actions_dict,
**kwargs_to_shard)
# except:
# print(gm)
output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare)
assert_close((output - output_to_compare).sum(), zero_tensor)
# backward result compare
loss = output.sum()
loss_to_compare = output_to_compare.sum()
loss.backward()
loss_to_compare.backward()
for key in grad_to_shard_dict.keys():
grad_to_shard = grad_to_shard_dict[key]
grad_to_compare = grad_to_compare_dict[key]
assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor)
# extract the strategy used in this iter
strategy_in_use = target_node.strategies_vector[strategy_index]
param_to_shard_dict = dict(model_to_shard.named_parameters())
param_to_compare_dict = dict(model_to_compare.named_parameters())
for name in param_to_shard_dict.keys():
param_name = name.split('.')[-1]
param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
grad_sharded = param_to_shard_dict[name].grad
grad_to_compare = param_to_compare_dict[name].grad
global_grad = to_global(grad_sharded, param_sharding_spec)
assert_close((global_grad - grad_to_compare).sum(), zero_tensor)
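
As a side note, _build_model_to_compare captures input gradients through tensor hooks; the wrapper closure is what keeps each hook bound to its own key. A minimal runnable sketch of that capture pattern in isolation:

import torch

grad_dict = {}

def capture(tensor: torch.Tensor, key):
    # the closure keeps `key` bound per tensor, mirroring the wrapper() trick above
    def hook_fn(grad):
        grad_dict[key] = grad
    tensor.register_hook(hook_fn)

x = torch.rand(4, 4, requires_grad=True)
capture(x, 'input')
(x * 3).sum().backward()
assert torch.allclose(grad_dict['input'], torch.full_like(x, 3.0))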