From 3b1b91eaf4c5490bb2eeec28f234a6541a922047 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 28 Dec 2022 19:29:08 +0800
Subject: [PATCH] [autoparallel] record parameter attribute in colotracer
 (#2217)

* [autoparallel] record parameter attribute in colotracer

* [autoparallel] fix construct_meta_info bug
---
 .../passes/runtime_apply_pass.py              |  4 ++--
 colossalai/fx/tracer/tracer.py                | 22 ++++++++++++++++++
 .../test_gpt/test_gpt2_performance.py         | 23 +++++++++++--------
 .../test_gpt/test_runtime_with_gpt_modules.py |  2 --
 .../test_gpt/test_solver_with_gpt_module.py   |  3 +--
 .../test_node_handler/test_addmm_handler.py   |  5 +++-
 6 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index caf118c89..df4a3fde7 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -174,8 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                           runtime_apply,
                                                           args=(node, origin_dict_node, input_dict_node,
                                                                 node_to_index_dict[node], user_node_index))
-        meta_info = construct_meta_info(node, user_node)
-        setattr(shape_consistency_node, 'best_metainfo', meta_info)
+        # meta_info = construct_meta_info(node, user_node)
+        # setattr(shape_consistency_node, 'best_metainfo', meta_info)
 
         new_args = list(user_node.args)
         new_kwargs = dict(user_node.kwargs)
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index bf6f9c23b..1ae31f958 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -229,6 +229,15 @@ class ColoTracer(Tracer):
         args_metas, kwargs_metas = extract_meta(*args, **kwargs)
 
         if kind == "call_function":
+            # Our meta data will not record the nn.parameter.Parameter attribute.
+            # It works fine in most cases, but it may cause some problems after
+            # the bias addition manipulation.
+            # Therefore, we need to record the nn.parameter.Parameter attribute for the operation
+            # added by the bias addition manipulation following the get_attr node.
+            convert_to_parameter = False
+            if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
+                                                                          torch.nn.parameter.Parameter):
+                convert_to_parameter = True
             # fetch patched function
             if meta_patched_function.has(target):
                 meta_target = meta_patched_function.get(target)
@@ -241,7 +250,18 @@
             meta_out = meta_target(*args_metas, **kwargs_metas)
             if isinstance(meta_out, torch.Tensor):
                 meta_out = meta_out.to(device="meta")
+            if convert_to_parameter:
+                meta_out = torch.nn.Parameter(meta_out)
+
         elif kind == "call_method":
+            # Our meta data will not record the nn.parameter.Parameter attribute.
+            # It works fine in most cases, but it may cause some problems after
+            # the bias addition manipulation.
+            # Therefore, we need to record the nn.parameter.Parameter attribute for the operation
+            # added by the bias addition manipulation following the get_attr node.
+ convert_to_parameter = False + if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter): + convert_to_parameter = True method = getattr(args_metas[0].__class__, target) # fetch patched method @@ -251,6 +271,8 @@ class ColoTracer(Tracer): meta_target = method meta_out = meta_target(*args_metas, **kwargs_metas) + if convert_to_parameter: + meta_out = torch.nn.Parameter(meta_out) elif kind == "call_module": if not hasattr(self, "orig_forward"): raise AttributeError(f"{self} does not have an attribute called orig_forward") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py index 87155307f..ac5b1d983 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py @@ -35,13 +35,14 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2LMHeadModel, GPTLMLoss -BATCH_SIZE = 128 -SEQ_LENGTH = 128 -HIDDEN_DIM = 4096 -NUM_HEADS = 32 +BATCH_SIZE = 32 +SEQ_LENGTH = 256 +HIDDEN_DIM = 16384 +NUM_HEADS = 128 NUM_LAYERS = 4 VOCAB_SIZE = 50257 NUM_STEPS = 10 +FP16 = True def get_cpu_mem(): @@ -57,7 +58,8 @@ def get_mem_info(prefix=''): def get_tflops(model_numel, batch_size, seq_len, step_time): - return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4 # Randomly Generated Data @@ -72,8 +74,11 @@ def main(): launch_from_torch(config={}) logger = get_dist_logger() config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) - - model = GPT2LMHeadModel(config=config).to('cuda') + if FP16: + model = GPT2LMHeadModel(config=config).half().to('cuda') + else: + model = GPT2LMHeadModel(config=config).to('cuda') + global_numel = sum([p.numel() for p in model.parameters()]) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) @@ -108,6 +113,7 @@ def main(): ret = solver.call_solver_serialized_args() solution = list(ret[0]) + # solution = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 13, 8, 9, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 9, 0, 0, 8, 0] print(solution) gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( gm, solution, device_mesh, strategies_constructor) @@ -125,9 +131,8 @@ def main(): criterion = GPTLMLoss() optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) - 
-    numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LENGTH)
+    get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
     torch.cuda.synchronize()
     model.train()
     # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
index 361c22d26..c7f9988f1 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
@@ -102,13 +102,11 @@ def check_attention_layer(rank, model_cls, world_size, port):
     else:
         input_sample = (
             input_ids.to('cuda'),
-            token_type_ids.to('cuda'),
             attention_mask.to('cuda'),
         )
     test_input_sample = copy.deepcopy(input_sample)
     meta_input_sample = {
         'input_ids': input_ids.to('meta'),
-        'token_type_ids': token_type_ids.to('meta'),
         'attention_mask': attention_mask.to('meta'),
     }
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
index 478b77e76..26ad0d3a0 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
@@ -50,9 +50,8 @@ def test_self_attention_block(model_cls):
         }
     else:
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         input_sample = {k: v.to('meta') for k, v in kwargs.items()}
 
     graph = tracer.trace(root=model, meta_args=input_sample)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
index a555db776..aa5a57474 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
@@ -130,7 +130,10 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
 
     assert mapping['other'].name == "transpose"
     assert mapping['other'].data.shape == torch.Size([16, 8])
-    assert mapping['other'].type == OperationDataType.ARG
+    if model_cls == AddmmModel:
+        assert mapping['other'].type == OperationDataType.ARG
+    else:
+        assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([8, 16])
 
     assert mapping['output'].name == "linear"
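
Note: the sketch below is not part of the patch. It is a minimal, standalone illustration of the convert_to_parameter bookkeeping that the ColoTracer change above introduces: when a view-like op (transpose/reshape/view) is traced on an nn.Parameter, the meta output is re-wrapped as an nn.Parameter so later passes (e.g. the addmm handler test) can classify the operand as a parameter instead of a plain argument. The helper name propagate_meta and the example shapes are assumptions made for this demonstration only.

import torch

# Hypothetical helper (not ColossalAI code): propagate a call through meta
# tensors and keep the Parameter attribute for the targets handled in the patch.
def propagate_meta(target, *args_metas, **kwargs_metas):
    convert_to_parameter = (
        target in (torch.transpose, torch.reshape, torch.Tensor.view)
        and isinstance(args_metas[0], torch.nn.parameter.Parameter)
    )
    meta_out = target(*args_metas, **kwargs_metas)
    if isinstance(meta_out, torch.Tensor):
        meta_out = meta_out.to(device="meta")
    if convert_to_parameter:
        # Without this re-wrap the Parameter attribute is lost, and a later
        # pass would see the transposed weight as a plain argument.
        meta_out = torch.nn.Parameter(meta_out)
    return meta_out

# Example: an 8x16 weight transposed to 16x8, mirroring the shapes asserted
# in test_addmm_handler.py. requires_grad=False keeps the wrap trivial here.
weight = torch.nn.Parameter(torch.empty(8, 16, device="meta"), requires_grad=False)
out = propagate_meta(torch.transpose, weight, 0, 1)
print(type(out).__name__, tuple(out.shape))  # Parameter (16, 8)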