[fix] rm output.data after send fwd;

pull/6034/head
duanjunwen 2024-09-03 14:12:17 +08:00
parent a48afc4a66
commit ab643c9af7
3 changed files with 25 additions and 49 deletions


@@ -25,6 +25,24 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
             req.wait()


+def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
+    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
+
+    This method should be called right after the output tensor has been
+    sent to the next pipeline stage. At this point, the output tensor is
+    only useful for its '.grad_fn' field, and not its '.data'.
+    """
+    if (out is None) or (not deallocate_pipeline_outputs):
+        print(
+            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
+        )
+        return
+    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
+    assert out._base is None, "counter-productive to free a view of another tensor."
+    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
+    out.data.storage().resize_(0)
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
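
Note (not part of the patch): a minimal standalone sketch of what the pseudo-deallocation above does. Resizing the underlying storage to zero releases the activation memory, while the tensor's shape metadata and `.grad_fn` survive, which is all the scheduler still needs once the real data has been handed to the next stage.

    import torch

    # Illustrative only; CPU tensors are assumed here, but the call behaves
    # the same way for CUDA activations inside the scheduler.
    x = torch.randn(1024, 1024, requires_grad=True)
    out = x * 2                                   # non-leaf tensor with a grad_fn

    print(out.data.storage().size())              # 1048576 elements held
    out.data.storage().resize_(0)                 # pseudo-deallocate the payload
    print(out.data.storage().size())              # 0 -> activation memory released
    print(out.shape, out.grad_fn)                 # shape metadata and graph survive

    # Backward never reads the freed payload of `out` in this example, because
    # the multiply's backward only needs the incoming gradient and the constant 2.
    torch.autograd.backward(out, grad_tensors=torch.ones(1024, 1024))
    print(x.grad[0, 0])                           # tensor(2.)
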
@@ -562,10 +580,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         )
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
-        self.output_tensors[model_chunk_id].append(output_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
         # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)

         # Step3: send fwd
         # add output to send_fwd_buffer
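
A hedged sketch of the bookkeeping pattern used above (standalone, with made-up stand-ins like `forward_chunk` and `send_buffer`; nothing here is the scheduler's real API): the tensor that actually carries data is the one placed in the send-forward buffer, while the copy kept for backward-b and backward-w only needs its autograd graph, so its storage can be freed right after the send.

    import torch

    def forward_chunk(x):
        # Hypothetical stand-in for one model-chunk forward pass.
        return torch.tanh(x @ x.t())

    x = torch.randn(8, 8, requires_grad=True)
    output_obj = forward_chunk(x)

    send_buffer = [output_obj]                    # the real tensor goes to send_forward

    # Locally kept copy: only its grad_fn matters, so drop its payload right away.
    detached_output_obj = output_obj.clone()
    detached_output_obj.data.storage().resize_(0)

    # Later, backward-b still runs through the kept copy's graph; the saved
    # activations it needs (e.g. tanh's output) live in `output_obj`, not in
    # the freed clone.
    torch.autograd.backward(detached_output_obj, grad_tensors=torch.ones(8, 8))
    print(x.grad.norm())
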


@@ -2,8 +2,7 @@ from .albert import *
 from .bert import *
 from .blip2 import *
 from .bloom import *
-# from .chatglm2 import *
 from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *


@@ -14,7 +14,6 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -701,56 +700,13 @@ def run_with_hybridplugin(test_config):
     ],
 )
 def run_with_moehybridplugin(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
     test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
     test_config["initial_scale"] = 2**16  # avoid overflow
-    model_list = [
-        "transformers_bert",
-    ]
-    clear_layout_converter()
     torch.set_default_dtype(torch.bfloat16)
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         data_gen_fn()
-        # print(f"data {data}")
-        # if name in model_list:
-        #     (
-        #         org_model,
-        #         org_optimizer,
-        #         sharded_model,
-        #         sharded_optimizer,
-        #         criterion,
-        #         booster,
-        #     ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
-        #     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        #         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-        #     )
-        #     stage_manager = booster.plugin.stage_manager
-        #     tp_group = booster.plugin.tp_group
-        #     bert = unwrap_model(org_model, "BertModel", "bert")
-        #     sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
-        #     weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
-        #     org_optimizer.step()
-        #     sharded_optimizer.step()
-        #     # check weights
-        #     if test_config["precision"] == "bf16":
-        #         atol, rtol = 5e-4, 5e-4
-        #     else:
-        #         atol, rtol = 5e-4, 5e-4
-        #     if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
-        #         check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
-        #     # check optim states
-        #     # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
-        #     clear_layout_converter()
-        #     Randomizer.reset_index()
-        #     torch.cuda.empty_cache()
-        #     print(f"Bert Model Zoo Test Passed")
     # TODO:6) support booster & Hybrid base 4)