From ab643c9af74a57d7e5fcdbf38c31b596db819a5b Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Tue, 3 Sep 2024 14:12:17 +0800
Subject: [PATCH] [fix] rm output.data after send fwd;

---
 .../pipeline/schedule/zero_bubble_pp.py       | 25 +++++++++-
 tests/kit/model_zoo/transformers/__init__.py  |  3 +-
 .../test_schedule/test_zerobubble_pp.py       | 46 +------------------
 3 files changed, 25 insertions(+), 49 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index e24ca5ac1..2505be4d4 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -25,6 +25,24 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
             req.wait()
 
 
+def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
+    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
+
+    This method should be called right after the output tensor has been
+    sent to the next pipeline stage. At this point, the output tensor is
+    only useful for its '.grad_fn' field, and not its '.data'.
+    """
+    if (out is None) or (not deallocate_pipeline_outputs):
+        print(
+            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
+        )
+        return
+    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
+    assert out._base is None, "counter-productive to free a view of another tensor."
+    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
+    out.data.storage().resize_(0)
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -562,10 +580,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
-        self.output_tensors[model_chunk_id].append(output_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
 
         # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
 
         # Step3: send fwd
         # add output to send_fwd_buffer
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 029968231..4adc38619 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -2,8 +2,7 @@ from .albert import *
 from .bert import *
 from .blip2 import *
 from .bloom import *
-
-# from .chatglm2 import *
+from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 9d0d39199..d5b76f66c 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -14,7 +14,6 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
@@ -701,56 +700,13 @@ def run_with_hybridplugin(test_config):
     ],
 )
 def run_with_moehybridplugin(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
     test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
     test_config["initial_scale"] = 2**16  # avoid overflow
     model_list = [
         "transformers_bert",
     ]
-    clear_layout_converter()
-    torch.set_default_dtype(torch.bfloat16)
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        data_gen_fn()
-        # print(f"data {data}")
-        # if name in model_list:
-        # (
-        #     org_model,
-        #     org_optimizer,
-        #     sharded_model,
-        #     sharded_optimizer,
-        #     criterion,
-        #     booster,
-        # ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
-
-        # org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        #     org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-        # )
-
-        # stage_manager = booster.plugin.stage_manager
-        # tp_group = booster.plugin.tp_group
-
-        # bert = unwrap_model(org_model, "BertModel", "bert")
-        # sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
-        # weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
-
-        # org_optimizer.step()
-        # sharded_optimizer.step()
-
-        # # check weights
-        # if test_config["precision"] == "bf16":
-        #     atol, rtol = 5e-4, 5e-4
-        # else:
-        #     atol, rtol = 5e-4, 5e-4
-        # if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
-        #     check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
-        # # check optim states
-        # # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
-
-        # clear_layout_converter()
-        # Randomizer.reset_index()
-        # torch.cuda.empty_cache()
-        # print(f"Bert Model Zoo Test Passed")
 
 
 # TODO:6) support booster & Hybrid base 4)
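
Note on the deallocation change: below is a minimal, self-contained sketch of why freeing the sent activation's storage is safe. It assumes a recent PyTorch build; the tensor names and shapes are illustrative only and are not ColossalAI code. After clone(), the clone's grad_fn still points into the autograd graph, so resizing its storage to zero releases the activation memory while backward b/w can still run through the graph given an explicit incoming gradient:

    import torch

    # Toy stand-ins for a stage's parameters and output (illustrative shapes only).
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 8, requires_grad=True)
    output_obj = (x @ w).relu()  # activation that would be sent to the next stage

    # Mirror the patch: keep a clone whose grad_fn points into the graph,
    # then drop the clone's storage once the real output has been sent.
    detached_output_obj = output_obj.clone()
    detached_output_obj.data.storage().resize_(0)  # frees the clone's activation memory

    # Backward needs only the graph and the incoming gradient, not '.data';
    # the tensor's shape metadata survives the storage resize, so the shape check passes.
    grad_from_next_stage = torch.ones(4, 8)
    torch.autograd.backward(detached_output_obj, grad_tensors=grad_from_next_stage)
    assert x.grad is not None and w.grad is not None

In the scheduler itself, output_obj is what goes into send_fwd_buffer, while the storage-freed clone is all that the later backward b and backward w steps need to retain.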