From 4851f2d60738d1d9ff78b9892683662433a78645 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Mon, 26 Dec 2022 21:57:39 +0800
Subject: [PATCH] [autoparallel] update_getattr_handler (#2193)

---
 .../passes/runtime_preparation_pass.py        | 25 +++----
 .../tensor_shard/node_handler/node_handler.py |  7 +-
 .../strategy/getattr_generator.py             | 62 ++++++++++++----
 .../test_node_handler/test_addmm_handler.py   | 73 +++++++++++++------
 .../test_node_handler/test_getattr_handler.py | 11 ++-
 .../test_node_handler/utils.py                | 16 +++-
 6 files changed, 136 insertions(+), 58 deletions(-)

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 0b898a43e..b29ff3a65 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -6,6 +6,7 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
@@ -96,27 +97,23 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
             # to the same strategy of the user node.
             if node.op == 'get_attr':
                 assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
-                new_sharding_spec = target_sharding_specs[0]
-                user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
-                op_data_in_user = user_strategy.get_op_data_by_name(str(node))
-                origin_node_sharding_spec_dict[index] = new_sharding_spec
+                target_node = node.strategies_vector.successor_nodes[0]
+                node_name = str(node)
+                if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+                    node_name = str(target_node)
+                    target_node = target_node.strategies_vector.successor_nodes[0]
+                user_strategy = target_node.best_strategy
+                op_data_in_user = user_strategy.get_op_data_by_name(node_name)
                 origin_pending_strategy = node.best_strategy
                 origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
-                new_sharding_specs = origin_pending_strategy.sharding_specs
-                new_sharding_specs[origin_op_data] = new_sharding_spec
+
                 new_communication_actions = {}
                 if op_data_in_user in user_strategy.communication_actions:
                     new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
                     new_communication_action.arg_index = 0
                     new_communication_actions[origin_op_data] = new_communication_action
-                new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
-                                                sharding_specs=new_sharding_specs,
-                                                compute_cost=origin_pending_strategy.compute_cost,
-                                                communication_cost=origin_pending_strategy.communication_cost,
-                                                memory_cost=origin_pending_strategy.memory_cost,
-                                                communication_actions=new_communication_actions)
-                setattr(node, 'best_strategy', new_strategy)
-                setattr(node, 'sharding_spec', new_sharding_spec)
+                node.best_strategy.communication_actions = new_communication_actions
+
             comm_action_dict = {}
             for op_data, comm_action in node.best_strategy.communication_actions.items():
                 comm_action_dict[op_data.name] = comm_action
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 6d603f63e..812b4b169 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -86,12 +86,7 @@ class NodeHandler(ABC):
         if prev_sharding_spec is None:
             return TrainCycleItem(fwd=0, bwd=0, total=0)
         elif isinstance(prev_sharding_spec, ShardingSpec):
-            if isinstance(data, torch.nn.parameter.Parameter):
-                # we won't compute the resharding cost for the parameters,
-                # since the parameters will be sharded before runtime and
-                # not converted during runtime.
-                return TrainCycleItem(fwd=0, bwd=0, total=0)
-            elif isinstance(data, torch.Tensor):
+            if isinstance(data, torch.Tensor):
                 dtype = data.dtype
                 size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                 _, _, consistency_cost = shape_consistency_manager.shape_consistency(
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
index 753ab1726..bbeb9a639 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
@@ -1,6 +1,12 @@
 from typing import List
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
+from colossalai.tensor.sharding_spec import ShardingSpecException
 
 from .strategy_generator import StrategyGenerator
 
@@ -37,17 +43,47 @@ class GetattrGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
+    def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
+        # we check for the output logical shape to get the number of dimensions
+        dim_partition_list = []
+        dim_size = len(self.op_data['output'].logical_shape)
+
+        # enumerate all the 2D sharding cases
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_2d)
+
+        # enumerate all the 1D sharding cases
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_0)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        dim_partition_list.extend(sharding_list_1d_on_dim_1)
+
+        # add empty dict for fully replicated case
+        dim_partition_list.append({})
+
+        # sharding strategy bookkeeping
+        strategy_list = []
+
+        # convert these dim partition dict to sharding strategy
+        for dim_partition_dict in dim_partition_list:
+            dim_partition_dict_mapping = dict(output=dim_partition_dict)
+
+            try:
+                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+                communication_action_mapping = {}
+
+                # get name
+                name = f"get_attr {sharding_spec_mapping['output'].sharding_sequence}"
+                sharding_strategy = self.get_sharding_strategy(
+                    name=name,
+                    sharding_spec_mapping=sharding_spec_mapping,
+                    communication_action_mapping=communication_action_mapping)
+                strategy_list.append(sharding_strategy)
+            except ShardingSpecException:
+                continue
+
+        return strategy_list
+
     def collate_strategies(self) -> List[ShardingStrategy]:
-        dim_partition_dict_mapping = {
-            "output": {},
-        }
-        communication_action_mapping = {}
-        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-
-        name = 'Replica Attribute'
-
-        strategy = self.get_sharding_strategy(name=name,
-                                              sharding_spec_mapping=sharding_spec_mapping,
-                                              communication_action_mapping=communication_action_mapping)
-
-        return [strategy]
+        return self.enumerate_all_possible_output(0, 1)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
index 767864296..a555db776 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
@@ -35,25 +35,59 @@ class AddmmModel(nn.Module):
         return x
 
 
-def check_linear_function_handler(rank, input_shape, world_size, port):
+class AddmmModel_with_param(nn.Module):
+
+    def __init__(self, weight_shape, bias_shape):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.rand(weight_shape))
+        self.bias = torch.nn.Parameter(torch.rand(bias_shape))
+
+    def forward(self, m1):
+        x = torch.addmm(self.bias, m1, self.weight, beta=3, alpha=2)
+        return x
+
+
+def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = AddmmModel().cuda()
+    if model_cls == AddmmModel:
+        model = AddmmModel().cuda()
+    else:
+        model = AddmmModel_with_param(weight_shape=(8, 16), bias_shape=input_shape).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(input_shape).cuda()
-    m1 = torch.rand(4, 8).cuda()
-    m2 = torch.rand(8, 16).cuda()
-    # the index of addmm node in computation graph
-    node_index = 4
-    # strategy number of linear node
-    strategy_number = 14
-    # construct input args
-    input_args = [input, m1, m2]
-    # construct meta arg names
-    meta_arg_names = ['input', 'm1', 'm2']
+    if model_cls == AddmmModel:
+        input = torch.rand(input_shape).cuda()
+        m1 = torch.rand(4, 8).cuda()
+        m2 = torch.rand(8, 16).cuda()
+        # construct input args
+        input_args = [input, m1, m2]
+        # construct meta arg names
+        meta_arg_names = ['input', 'm1', 'm2']
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+
+        # the index of addmm node in computation graph
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
+    else:
+        m1 = torch.rand(4, 8).cuda()
+        # construct input args
+        input_args = [m1]
+        # construct meta arg names
+        meta_arg_names = ['m1']
+        # the index of addmm node in computation graph
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
+
     numerical_test_for_node_strategy(model=model,
                                      device_mesh=device_mesh,
                                      node_index=node_index,
@@ -73,12 +107,7 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
     #     %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
     #     %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
     #     return add
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(input_shape).to('meta'),
-                             'm1': torch.rand(4, 8).to('meta'),
-                             'm2': torch.rand(8, 16).to('meta'),
-                         })
+    graph = tracer.trace(model, meta_args=meta_args_for_tracer)
     gm = ColoGraphModule(model, graph)
     # [input_1, m1, m2, addmm, output]
     node_list = list(graph.nodes)
@@ -155,11 +184,13 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @parameterize('input_shape', [(16,), (4, 16)])
+@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
 @rerun_if_address_is_in_use()
-def test_addmm_handler(input_shape):
+def test_addmm_handler(input_shape, model_cls):
     world_size = 4
-    run_func_function = partial(check_linear_function_handler,
+    run_func_function = partial(check_addmm_function_handler,
                                 input_shape=input_shape,
+                                model_cls=model_cls,
                                 world_size=world_size,
                                 port=free_port())
     mp.spawn(run_func_function, nprocs=world_size)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
index ad093c2ed..d3af5ac6f 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
@@ -39,6 +39,7 @@ def test_getattr_handler():
                                      strategies_vector=getattr_strategies_vector)
     getattr_handler.register_strategy(compute_resharding_cost=False)
 
+    # check operation data mapping
     mapping = getattr_handler.get_operation_data_mapping()
 
@@ -51,7 +52,15 @@ def test_getattr_handler():
     assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3))
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in getattr_handler.strategies_vector]
-    assert "Replica Attribute" in strategy_name_list
+    assert 'get_attr [S0, S1, R, R]' in strategy_name_list
+    assert 'get_attr [S1, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S01, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S01, R, R]' in strategy_name_list
+    assert 'get_attr [S0, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S1, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S1, R, R]' in strategy_name_list
+    assert 'get_attr [R, R, R, R]' in strategy_name_list
 
 
 if __name__ == '__main__':
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index 9d9a625a4..d02e1e31e 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -149,10 +149,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                 param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
             else:
                 if 'weight' in name:
-                    param_sharding_spec = list(graph.nodes)[4].sharding_spec
-                elif 'bias' in name:
-                    param_sharding_spec = list(graph.nodes)[5].sharding_spec
+                    param_sharding_spec = None
+                    for node in list(graph.nodes):
+                        if 'weight' in node.name:
+                            param_sharding_spec = node.sharding_spec
+
+                elif 'bias' in name:
+                    param_sharding_spec = None
+
+                    for node in list(graph.nodes):
+                        if 'bias' in node.name:
+                            param_sharding_spec = node.sharding_spec
+
+                assert param_sharding_spec is not None
                 grad_sharded = param_to_shard_dict[name].grad
                 grad_to_compare = param_to_compare_dict[name].grad
                 global_grad = to_global(grad_sharded, param_sharding_spec)
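Note (reviewer illustration, not part of the patch): the sketch below is a minimal, standalone approximation of the dim-partition enumeration that the new GetattrGenerator.enumerate_all_possible_output relies on. The local helpers enumerate_1d, enumerate_2d and sharding_sequence are hypothetical simplifications written only for this note; the real helpers are enumerate_all_possible_1d_sharding / enumerate_all_possible_2d_sharding from colossalai.auto_parallel.tensor_shard.utils, and the real generator additionally drops any candidate whose ShardingSpec construction raises ShardingSpecException, which is why the updated test only asserts sharding on the first two dims of the (16, 4, 3, 3) parameter on a 2x2 mesh.

from itertools import permutations


def enumerate_1d(mesh_dim, dim_size):
    # shard a single tensor dim along one mesh dim, e.g. {0: [0]} -> [S0, R, R, R]
    return [{tensor_dim: [mesh_dim]} for tensor_dim in range(dim_size)]


def enumerate_2d(mesh_dim_0, mesh_dim_1, dim_size):
    # shard two different tensor dims on the two mesh dims ({0: [0], 1: [1]} -> [S0, S1, R, R]),
    # plus the flattened case where one tensor dim uses both mesh dims ({0: [0, 1]} -> [S01, R, R, R])
    dim_partition_list = [{d0: [mesh_dim_0], d1: [mesh_dim_1]} for d0, d1 in permutations(range(dim_size), 2)]
    dim_partition_list += [{tensor_dim: [mesh_dim_0, mesh_dim_1]} for tensor_dim in range(dim_size)]
    return dim_partition_list


def sharding_sequence(dim_partition_dict, dim_size):
    # render a dim-partition dict the way the strategy names in the new test are spelled
    seq = []
    for tensor_dim in range(dim_size):
        mesh_dims = dim_partition_dict.get(tensor_dim)
        seq.append('R' if mesh_dims is None else 'S' + ''.join(str(d) for d in mesh_dims))
    return '[' + ', '.join(seq) + ']'


if __name__ == '__main__':
    dim_size = 4    # rank of the (16, 4, 3, 3) parameter used in test_getattr_handler.py
    candidates = enumerate_2d(0, 1, dim_size) + enumerate_1d(0, dim_size) + enumerate_1d(1, dim_size)
    candidates.append({})    # fully replicated fallback, always kept
    for dim_partition in candidates:
        # the real generator converts each candidate via to_sharding_spec_mapping and
        # silently skips the ones that raise ShardingSpecException (non-divisible dims)
        print(f"get_attr {sharding_sequence(dim_partition, dim_size)}")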