[fix] fix require_grad & deallocate call;

pull/6065/head
duanjunwen 2024-09-19 05:53:03 +00:00
parent 1f5c7258aa
commit 6ee9584b9a
2 changed files with 16 additions and 33 deletions


@@ -137,7 +137,7 @@ def require_grad(x: Any) -> None:
     Args:
         x (Any): Object to be called.
     """
-    if isinstance(x, torch.Tensor) and x.requires_grad:
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
         x.requires_grad_()
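
For context: the old guard was inverted, so `requires_grad_()` only ran on tensors that already required grad (a no-op), and detached activations received from the previous stage stayed grad-less. A minimal standalone sketch of the corrected behavior (the real helper lives in `colossalai.pipeline.schedule._utils`; the tensor here is a hypothetical stand-in):

import torch

def require_grad(x):
    # Corrected guard: set the flag only when it is not already set.
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()

recv = torch.empty(4)        # stand-in for an activation received from the previous stage
require_grad(recv)
assert recv.requires_grad    # the tensor can now anchor the local backward pass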


@@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager

-from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    clone,
+    deallocate,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule

 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -24,35 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
         req.wait()


-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.untyped_storage().resize_(0)
-
-
-def require_grad(tensor):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if tensor is None:
-        return
-    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
-    assert tensor._base is None, "counter-productive to free a view of another tensor."
-    tensor.requires_grad_()


 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
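
The two helpers deleted above duplicated what the scheduler now imports from `._utils` (`deallocate`, `require_grad`). For reference, a self-contained sketch of the pseudo-deallocation idiom they implemented: free a sent output's payload while keeping its autograd graph alive. The `deallocate` below is a stand-in mirroring the deleted body, not the exact `_utils` implementation:

import torch

def deallocate(out):
    # Stand-in mirroring the deleted helper: free '.data' storage, keep '.grad_fn'.
    if out is None or not isinstance(out, torch.Tensor):
        return
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data.untyped_storage().resize_(0)

x = torch.randn(1024, requires_grad=True)
y = x * 2                                   # y carries a grad_fn; pretend it was just sent downstream
deallocate(y)
assert y.grad_fn is not None                # autograd graph survives for the backward pass
assert y.untyped_storage().size() == 0      # but the data buffer itself is freed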
@@ -590,7 +572,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)

         # Here, let input_obj.requires_grad_()
-        if input_obj is not None:
+        # if input_obj is not None:
+        if not isinstance(input_obj, torch.Tensor):
             tree_map(require_grad, input_obj)

         # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
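
With the new guard, `require_grad` is applied across nested containers via `tree_map` (mutating each tensor leaf in place), while a bare tensor `input_obj` is left for the other code path. A small illustration, assuming `tree_map` here is `torch.utils._pytree.tree_map` and using hypothetical payload names:

import torch
from torch.utils._pytree import tree_map  # assumed source of tree_map

def require_grad(x):
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()

# A received input_obj may be a dict/tuple of tensors rather than a single tensor.
input_obj = {"hidden_states": torch.empty(2, 8), "residual": torch.empty(2, 8)}
if not isinstance(input_obj, torch.Tensor):
    tree_map(require_grad, input_obj)     # called for its in-place side effect

assert all(t.requires_grad for t in input_obj.values())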
@@ -614,7 +597,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 pass
             else:
                 # deallocate output
-                tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
+                tree_map(deallocate, deallocate_output_obj)

         # add input and output object for backward b
         if input_obj is not None:
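
Since `_utils.deallocate` takes a single argument, the `partial(...)` wrapper is no longer needed and `tree_map` can walk the (possibly nested) output directly. A sketch of the simplified call shape, with `deallocate` again a stand-in for the `_utils` helper:

import torch
from torch.utils._pytree import tree_map  # assumed source of tree_map

def deallocate(x):
    # unary stand-in for _utils.deallocate (see the sketch earlier)
    if isinstance(x, torch.Tensor) and x._base is None:
        x.data.untyped_storage().resize_(0)

outputs = (torch.randn(4, requires_grad=True).relu(),
           torch.randn(4, requires_grad=True).tanh())
# old: tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), outputs)
# new: the helper is already unary, so it maps over the leaves directly
tree_map(deallocate, outputs)
assert all(o.untyped_storage().size() == 0 for o in outputs)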