[fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

pull/6065/head
duanjunwen 2024-09-20 09:48:35 +00:00
parent 1739df423c
commit da3220f48c
2 changed files with 10 additions and 10 deletions

View File

@ -169,8 +169,8 @@ def clone(x: Any) -> Any:
return x
def deallocate(x: Any) -> Any:
"""Call deallocate() on a tensor.
def release_tensor_data(x: Any) -> Any:
"""Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
Args:
x (Any): Object to be called.

View File

@ -14,12 +14,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import (
clone,
deallocate,
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
release_tensor_data,
require_grad,
retain_grad,
to_device,
@ -488,8 +488,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj)
output_obj_, _ = tree_flatten(output_obj) # LOSS
output_obj_grad_, _ = tree_flatten(output_obj_grad) # None
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else:
@ -614,20 +614,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs=outputs,
)
# Step3: deallocate output for bwd b & w; (do not detach output)
# Step3: release_tensor_data output for bwd b & w; (do not detach output)
deallocate_output_obj = tree_map(clone, output_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not deallocate bwd LOSS
# We should not release_tensor_data bwd LOSS
pass
else:
# deallocate output
tree_map(deallocate, deallocate_output_obj)
# release_tensor_data output
tree_map(release_tensor_data, deallocate_output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not deallocate loss, deallocate other output_obj;
# Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)