
[fix] rm output.data after send fwd;

Branch: pull/6034/head
Author: duanjunwen · 3 months ago
Commit: ab643c9af7
3 changed files:

  1. colossalai/pipeline/schedule/zero_bubble_pp.py (25 lines changed)
  2. tests/kit/model_zoo/transformers/__init__.py (3 lines changed)
  3. tests/test_pipeline/test_schedule/test_zerobubble_pp.py (46 lines changed)

colossalai/pipeline/schedule/zero_bubble_pp.py (25 lines changed)

@@ -25,6 +25,24 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
            req.wait()


def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    if (out is None) or (not deallocate_pipeline_outputs):
        print(
            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
        )
        return
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
    out.data.storage().resize_(0)


class ZeroBubbleVPipeScheduler(PipelineSchedule):
    def __init__(
        self,
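Note: outside the diff, here is a minimal, self-contained sketch (toy tensors and shapes are hypothetical, not taken from the scheduler) of what this storage trick buys: resizing the storage to zero releases the payload, while the shape metadata and the autograd graph reachable through `.grad_fn` survive, so a later backward that only needs to supply an upstream gradient still runs.

import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = x @ w                           # stand-in for a pipeline-stage output

kept = out.clone()                    # clone shares the autograd graph via grad_fn
kept.data.storage().resize_(0)        # free the payload; shape metadata stays (4, 8)

print(kept.shape)                     # torch.Size([4, 8]) -- metadata intact
print(kept.data.storage().size())     # 0 -- activation memory released
print(kept.grad_fn)                   # <CloneBackward0 ...> -- graph still reachable

# Backward needs only the graph plus an upstream gradient, not kept's payload;
# matmul's backward reads its saved inputs (x, w), which were never freed.
torch.autograd.backward(kept, grad_tensors=torch.ones(4, 8))
print(x.grad.shape, w.grad.shape)     # torch.Size([4, 8]) torch.Size([8, 8])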
@@ -562,10 +580,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        )
        # add input and output object for backward b
        self.input_tensors[model_chunk_id].append(input_obj)
        self.output_tensors[model_chunk_id].append(output_obj)
        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
        detached_output_obj = output_obj.clone()
        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
        self.output_tensors[model_chunk_id].append(detached_output_obj)
        # add output object for backward w
        self.output_tensors_dw[model_chunk_id].append(output_obj)
        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
        # Step3: send fwd
        # add output to send_fwd_buffer
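Note: a hedged, self-contained sketch of the bookkeeping idea behind this hunk (the buffer and list names mirror the scheduler, but the schedule loop, P2P calls, and model chunk are omitted or faked): the backward-b and backward-w lists now hold only the payload-free clone, so the send buffer owns the last scheduler-side reference to the real activation, and dropping that entry once the send completes is what actually releases the memory.

import torch

def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    # same trick as the scheduler helper: drop the payload, keep the graph
    if (out is None) or (not deallocate_pipeline_outputs):
        return
    out.data.storage().resize_(0)

# hypothetical stand-ins for one micro-batch on one model chunk
output_obj = torch.randn(2, 4, requires_grad=True) @ torch.randn(4, 4)

send_fwd_buffer = []      # owns the real payload until the P2P send finishes
output_tensors = []       # kept for backward-b; only .grad_fn matters
output_tensors_dw = []    # kept for backward-w; only .grad_fn matters

detached_output_obj = output_obj.clone()
deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)

output_tensors.append(detached_output_obj)   # graph only, ~0 bytes of payload
output_tensors_dw.append(detached_output_obj)
send_fwd_buffer.append(output_obj)           # the only list holding the live data

print(send_fwd_buffer[0].data.storage().size())   # 8 elements -- still allocated
print(output_tensors[0].data.storage().size())    # 0 -- payload already released

# After the (here, imaginary) isend completes, the scheduler pops the buffer entry;
# backward-b/-w later run through detached_output_obj.grad_fn with the gradient
# received from the downstream stage, never touching the freed payload.
send_fwd_buffer.pop(0)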

tests/kit/model_zoo/transformers/__init__.py (3 lines changed)

@@ -2,8 +2,7 @@ from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
# from .chatglm2 import *
from .chatglm2 import *
from .command import *
from .deepseek import *
from .falcon import *

tests/test_pipeline/test_schedule/test_zerobubble_pp.py (46 lines changed)

@@ -14,7 +14,6 @@ from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -701,56 +700,13 @@ def run_with_hybridplugin(test_config):
    ],
)
def run_with_moehybridplugin(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    model_zoo.get_sub_registry("transformers_bert")
    test_config["use_lazy_init"] = False
    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        data_gen_fn()
        # print(f"data {data}")
        # if name in model_list:
        #     (
        #         org_model,
        #         org_optimizer,
        #         sharded_model,
        #         sharded_optimizer,
        #         criterion,
        #         booster,
        #     ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
        #     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        #         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
        #     )
        #     stage_manager = booster.plugin.stage_manager
        #     tp_group = booster.plugin.tp_group
        #     bert = unwrap_model(org_model, "BertModel", "bert")
        #     sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
        #     weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
        #     org_optimizer.step()
        #     sharded_optimizer.step()
        #     # check weights
        #     if test_config["precision"] == "bf16":
        #         atol, rtol = 5e-4, 5e-4
        #     else:
        #         atol, rtol = 5e-4, 5e-4
        #     if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        #         check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
        #     # check optim states
        #     # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
        #     clear_layout_converter()
        #     Randomizer.reset_index()
        #     torch.cuda.empty_cache()
        # print(f"Bert Model Zoo Test Passed")
    # TODO:6) support booster & Hybrid base 4)
