From b4cc59b61e4f8921eb2a06417279cddc3c5b6e33 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Thu, 27 Oct 2022 10:42:54 +0800 Subject: [PATCH] [autoparallel] add numerical test for node strategies (#1760) * [autoparallel] add numerical test for node strategies * polish code * polish code --- .../passes/runtime_apply_pass.py | 52 ++++++-- .../passes/runtime_preparation_pass.py | 1 + .../strategy/conv_strategy_generator.py | 24 ++-- .../strategy/strategy_generator.py | 6 +- .../tensor_shard/sharding_strategy.py | 1 + colossalai/device/device_mesh.py | 19 ++- colossalai/tensor/shape_consistency.py | 9 ++ colossalai/tensor/sharding_spec.py | 13 +- .../test_node_handler/test_conv_handler.py | 92 ++++++++++--- .../test_node_handler/utils.py | 126 ++++++++++++++++++ 10 files changed, 283 insertions(+), 60 deletions(-) create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 09f123665..cc2466273 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i """ origin_sharding_spec = origin_dict[node_index] target_sharding_spec = input_dict[node_index][user_node_index] - return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) @@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): if not hasattr(node, 'best_strategy') or node.op == 'output': continue - for user_node in node.strategies_vector.successor_nodes: - user_node_index = user_node.strategies_vector.predecessor_nodes.index(node) + for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): with mod_graph.inserting_before(user_node): shape_consistency_node = mod_graph.create_node('call_function', runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - - origin_index_args = user_node.args.index(node) new_args = list(user_node.args) - new_args[origin_index_args] = shape_consistency_node - user_node.args = new_args + new_kwargs = dict(user_node.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with shape_consistency_node + origin_index_args = new_args.index(node) + new_args[origin_index_args] = shape_consistency_node + user_node.args = new_args + elif str(node) in new_kwargs: + # substitute the origin node with shape_consistency_node + new_kwargs[str(node)] = shape_consistency_node + user_node.kwargs = new_kwargs return gm @@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): comm_actions = node.best_strategy.communication_actions for op_data, comm_action in comm_actions.items(): - comm_object = node.args[comm_action.arg_index] + if op_data.type == OperationDataType.PARAM: continue if comm_action.comm_type == CommType.BEFORE: + if comm_action.key_for_kwarg is not None: + comm_object = node.kwargs[comm_action.key_for_kwarg] + else: + comm_object = node.args[comm_action.arg_index] with mod_graph.inserting_before(node): comm_spec_apply_node = mod_graph.create_node('call_function', runtime_comm_spec_apply, args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name)) - new_args = list(node.args) 
- new_args[comm_action.arg_index] = comm_spec_apply_node - node.args = new_args + # the origin node may be a positional argument or key word argument of user node + if comm_action.key_for_kwarg is not None: + # substitute the origin node with comm_spec_apply_node + new_kwargs = dict(node.kwargs) + new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node + node.kwargs = new_kwargs + else: + # substitute the origin node with comm_spec_apply_node + new_args = list(node.args) + new_args[comm_action.arg_index] = comm_spec_apply_node + node.args = new_args + elif comm_action.comm_type == CommType.AFTER: with mod_graph.inserting_after(node): comm_spec_apply_node = mod_graph.create_node('call_function', @@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): if user == comm_spec_apply_node: continue new_args = list(user.args) - new_args[new_args.index(node)] = comm_spec_apply_node - user.args = tuple(new_args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if node in new_args: + # substitute the origin node with comm_spec_apply_node + new_args[new_args.index(node)] = comm_spec_apply_node + user.args = tuple(new_args) + elif str(node) in new_kwargs: + # substitute the origin node with comm_spec_apply_node + new_kwargs[str(node)] = comm_spec_apply_node + user.kwargs = new_kwargs return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 796a95ee4..00268e3f5 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh): if target_sharding_spec.dim_partition_dict != {}: origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) setattr(param, 'sharding_spec', origin_sharding_spec) + # TODO: build a ColoParamter class to manager the distributed parameters param_sharded = torch.nn.Parameter( shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, target_sharding_spec).detach().clone()) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index 83476e4fe..f7e4543f8 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -4,7 +4,6 @@ import warnings from functools import reduce from typing import List - from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, @@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ShardingStrategy, TrainCycleItem, ) - -from colossalai.auto_parallel.tensor_shard.utils import \ - ignore_sharding_exception - +from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.BEFORE) + comm_type=CommType.BEFORE, + arg_index=0) communication_action_mapping = {"input": input_comm_action} if 
self.is_param("other"): @@ -223,8 +220,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} @@ -277,8 +273,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) input_comm_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, @@ -316,8 +311,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} @@ -351,7 +345,8 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.BEFORE) + comm_type=CommType.BEFORE, + arg_index=0) communication_action_mapping = {"input": input_comm_action} @@ -441,8 +436,7 @@ class ConvStrategyGenerator(StrategyGenerator): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER, - arg_index=0) + comm_type=CommType.AFTER) communication_action_mapping = {"output": output_comm_action} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 8f57ee6a0..b3903b9d7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -109,7 +109,8 @@ class StrategyGenerator(ABC): communication_pattern: CollectiveCommPattern, logical_process_axis: Union[int, List[int]], comm_type: CommType, - arg_index: int = -1) -> CommAction: + arg_index: int = -1, + key_for_kwarg: any = None) -> CommAction: """ A factory method to produce a CommAction object. 
""" @@ -117,7 +118,8 @@ class StrategyGenerator(ABC): communication_pattern=communication_pattern, logical_process_axis=logical_process_axis), comm_type=comm_type, - arg_index=arg_index) + arg_index=arg_index, + key_for_kwarg=key_for_kwarg) def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 8dbb0014b..334fb10d7 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -115,6 +115,7 @@ class CommAction: comm_spec: CommSpec = None comm_type: CommType = None arg_index: int = -1 + key_for_kwarg: any = None @dataclass diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index df010e7d7..403bbe4ae 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -1,5 +1,6 @@ -from functools import reduce import operator +from functools import reduce + import torch import torch.distributed as dist @@ -11,7 +12,7 @@ class DeviceMesh: can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its own latency and bandwidth. We use alpha-beta model to model the communication cost. - + Arguments: physical_mesh_id (torch.Tensor): physical view of the devices in global rank. mesh_shape (torch.Size): shape of logical view. @@ -64,6 +65,18 @@ class DeviceMesh: def logical_mesh_id(self): return self._logical_mesh_id + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k != 'process_groups_dict': + setattr(result, k, __import__("copy").deepcopy(v, memo)) + else: + setattr(result, k, v) + + return result + def flatten(self): """ Flatten the logical mesh into an effective 1d logical mesh, @@ -90,7 +103,7 @@ class DeviceMesh: def create_process_groups_for_logical_mesh(self): ''' This method is used to initialize the logical process groups which will be used in communications - among logical device mesh. + among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. ''' diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index d96040817..4ec5ad9e9 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -28,6 +28,15 @@ class ShapeConsistencyOptions: pass +def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec): + shape_consistency_manager = ShapeConsistencyManager() + global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) + with torch.no_grad(): + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, + global_sharding_spec) + return global_tensor + + def set_shape_consistency_options(options: ShapeConsistencyOptions): """ Configure the shape consistency manager via function call. 
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index fababb6e7..37d397885 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -6,7 +6,6 @@ from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator) __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] @@ -23,7 +22,7 @@ class _DimSpec: This class is used internally in ShardingSpec. Argument: - shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. ''' @@ -62,7 +61,7 @@ class _DimSpec: def build_difference_2d_dict(self): ''' - Build a difference maping for 2D device mesh case. It will be used to + Build a difference maping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. ''' @@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException): class ShardingSpec: ''' Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong - to, the entire shape of the tensor before sharded, and the sharding sequence looks like + to, the entire shape of the tensor before sharded, and the sharding sequence looks like [R, R, S0, S1]. - + Argument: device_mesh(DeviceMesh): A logical view of a physical mesh. entire_shape(torch.Size): The entire shape of tensor before sharded. @@ -260,10 +259,10 @@ class ShardingSpec: # device_mesh_shape: (4, 4) sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) - + Output: 25 - + Argument: other(ShardingSpec): The ShardingSpec to compared with. 
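The shape_consistency.py hunk above adds the to_global() helper, which converts a sharded tensor back to its full (replicated) layout by targeting a ShardingSpec with an empty dim_partition_dict; the gradient check in the new utils.py relies on it. A minimal sketch of the intended usage follows, assuming a 4-GPU NCCL job launched the same way as the tests below (colossalai.initialize.launch with a 2x2 logical mesh); the shapes and the partition dict are illustrative, not taken from the patch.

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import to_global

# assumes torch.distributed is already initialized on 4 ranks
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

# a [16, 32] tensor sharded as [S0, S1]: dim 0 split over mesh axis 0, dim 1 over mesh axis 1,
# so each rank holds an [8, 16] shard
local_shard = torch.rand(8, 16).cuda()
sharding_spec = ShardingSpec(device_mesh, torch.Size([16, 32]), {0: [0], 1: [1]})
print(sharding_spec.sharding_sequence)    # expected to read like [S0, S1]

# gather the shard back to the full [16, 32] tensor, as the gradient check in utils.py does
global_tensor = to_global(local_shard, sharding_spec)
assert global_tensor.shape == torch.Size([16, 32])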
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index 97025729c..dc86712f6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,27 +1,44 @@ +from functools import partial + +import pytest import torch +import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing import parameterize +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port +from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -@parameterize('bias', [True, False]) -def test_conv_module_handler(bias): - model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta')) - tracer = ColoTracer() +def check_conv_module_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 + input = torch.rand(4, 4, 64, 64).cuda() + + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # index of conv node in this graph + node_index = 1 + # total number of conv strategies + strategy_number = 16 + numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input']) + tracer = ColoTracer() graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) conv_mod_node = list(graph.nodes)[1] strategies_vector = StrategiesVector(conv_mod_node) @@ -38,26 +55,26 @@ def test_conv_module_handler(bias): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta + # assert mapping['input'].data.is_meta assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) assert mapping['other'].name == "weight" - assert mapping['other'].data.is_meta + # assert mapping['other'].data.is_meta assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) assert mapping['other'].type == OperationDataType.PARAM assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) if bias: assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta + # assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape 
== torch.Size([16]) assert mapping['bias'].type == OperationDataType.PARAM assert mapping['bias'].logical_shape == torch.Size([16]) assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta + # assert mapping['output'].data.is_meta assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) assert mapping['output'].type == OperationDataType.OUTPUT @@ -129,9 +146,28 @@ class ConvModel(nn.Module): return x -@parameterize('bias', [True, False]) -def test_conv_function_handler(bias): - model = ConvModel() +def check_conv_function_handler(rank, bias, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = ConvModel().cuda() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + input = torch.rand(4, 4, 64, 64).cuda() + others = torch.rand(16, 4, 3, 3).cuda() + input_args = [input, others] + meta_arg_names = ['input', 'others'] + input_kwargs = {} + # total number of conv strategies + strategy_number = 16 + node_index = 2 + if bias: + bias_tensor = torch.rand(16).cuda() + input_kwargs['bias'] = bias_tensor + node_index += 1 + numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names, + input_kwargs) + tracer = ColoTracer() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -143,10 +179,6 @@ def test_conv_function_handler(bias): meta_args['bias'] = torch.rand(16).to('meta') graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) - physical_mesh_id = torch.arange(0, 4) - - mesh_shape = (2, 2) - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) if bias: conv_mod_node = list(graph.nodes)[3] @@ -248,6 +280,26 @@ def test_conv_function_handler(bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_module_handler(bias): + world_size = 4 + run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@parameterize('bias', [True, False]) +@rerun_if_address_is_in_use() +def test_conv_function_handler(bias): + world_size = 4 + run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + if __name__ == '__main__': test_conv_module_handler() test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py new file mode 100644 index 000000000..47ee6be79 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -0,0 +1,126 @@ +import copy +from typing import Dict, List + +import torch +from torch.fx import GraphModule + +from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass +from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer.tracer import ColoTracer +from 
colossalai.tensor.shape_consistency import to_global +from colossalai.testing.comparison import assert_close + + +def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): + + model_to_compare = copy.deepcopy(model) + args_to_compare = [] + kwargs_to_compare = {} + for arg_index, input_tensor in enumerate(input_args): + + def wrapper(param, index): + + def hook_fn(grad): + grad_dict[index] = grad + + param.register_hook(hook_fn) + + arg_to_compare = copy.deepcopy(input_tensor) + arg_to_compare.requires_grad = True + wrapper(arg_to_compare, arg_index) + # arg_to_compare.register_hook(hook_fn) + args_to_compare.append(arg_to_compare) + + for name, input_kwarg in input_kwargs.items(): + + def wrapper(param, name): + + def hook_fn(grad): + grad_dict[name] = grad + + param.register_hook(hook_fn) + + kwarg_to_compare = copy.deepcopy(input_kwarg) + kwarg_to_compare.requires_grad = True + wrapper(kwarg_to_compare, name) + kwargs_to_compare[name] = kwarg_to_compare + + return model_to_compare, args_to_compare, kwargs_to_compare + + +def numerical_test_for_node_strategy(model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}): + for strategy_index in range(strategy_number): + print(f'#strategy_index: {strategy_index}') + # We need to copy the model to avoid do backward more than once in same graph + grad_to_compare_dict = {} + grad_to_shard_dict = {} + model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_compare_dict) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, + grad_to_shard_dict) + + zero_tensor = torch.Tensor(0).cuda() + + tracer = ColoTracer() + input_sample = {} + for input_arg, meta_arg_name in zip(input_args, meta_arg_names): + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + for meta_kwarg_name, input_kwarg in input_kwargs.items(): + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + graph = tracer.trace(root=model_to_shard, meta_args=input_sample) + gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + solver_options = SolverOptions(fast=True) + strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) + strategies_constructor.build_strategies_and_cost() + target_node = list(graph.nodes)[node_index] + + # solution construction + solution_len = len(strategies_constructor.leaf_strategies) + solution = [0] * solution_len + solution[node_index] = strategy_index + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh) + gm = runtime_apply_pass(gm) + gm.recompile() + + # forward result compare + output = gm(*args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard) + # except: + # print(gm) + output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) + assert_close((output - output_to_compare).sum(), zero_tensor) + + # backward result compare + loss = output.sum() + loss_to_compare = output_to_compare.sum() + loss.backward() + loss_to_compare.backward() + for key in grad_to_shard_dict.keys(): + grad_to_shard = grad_to_shard_dict[key] 
+ grad_to_compare = grad_to_compare_dict[key] + assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor) + + # extract the strategy used in this iter + strategy_in_use = target_node.strategies_vector[strategy_index] + param_to_shard_dict = dict(model_to_shard.named_parameters()) + param_to_compare_dict = dict(model_to_compare.named_parameters()) + for name in param_to_shard_dict.keys(): + param_name = name.split('.')[-1] + param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) + grad_sharded = param_to_shard_dict[name].grad + grad_to_compare = param_to_compare_dict[name].grad + global_grad = to_global(grad_sharded, param_sharding_spec) + assert_close((global_grad - grad_to_compare).sum(), zero_tensor)
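The nested wrapper()/hook_fn closures in _build_model_to_compare above exist to freeze the loop variable: registering hook_fn directly inside the loop would capture arg_index (or name) late, so every hook would record its gradient under the final key. A small self-contained illustration of the same pattern (not part of the patch; make_hook is a renamed stand-in for the patch's wrapper):

import torch

grad_dict = {}
tensors = [torch.rand(2, 2, requires_grad=True) for _ in range(3)]


def make_hook(key):
    # binds `key` at call time, so each hook writes to its own slot
    def hook_fn(grad):
        grad_dict[key] = grad

    return hook_fn


for index, tensor in enumerate(tensors):
    tensor.register_hook(make_hook(index))

loss = sum((t * (i + 1)).sum() for i, t in enumerate(tensors))
loss.backward()
assert sorted(grad_dict.keys()) == [0, 1, 2]
print({k: v.mean().item() for k, v in grad_dict.items()})    # distinct gradient per input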