mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;
parent 1739df423c
commit da3220f48c
@@ -169,8 +169,8 @@ def clone(x: Any) -> Any:
     return x


-def deallocate(x: Any) -> Any:
-    """Call deallocate() on a tensor.
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.

     Args:
         x (Any): Object to be called.
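The renamed helper frees only the memory behind a tensor while leaving its shape, dtype, and grad_fn in place, so the autograd graph can still be walked later. A minimal sketch of what such a helper could look like under that docstring, assuming a recent PyTorch that exposes Tensor.untyped_storage() (the body is an illustration, not the repository's exact implementation):

from typing import Any

import torch


def release_tensor_data(x: Any) -> Any:
    """Free a tensor's data while keeping its metadata and grad_fn alive."""
    if isinstance(x, torch.Tensor):
        # Resizing the untyped storage to zero releases the memory backing
        # x.data; shape, dtype, and grad_fn survive for a later backward pass.
        return x.untyped_storage().resize_(0)
    if x is None:
        return None
    raise RuntimeError(f"Expected torch.Tensor or None, got {type(x)}")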
@@ -14,12 +14,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager

 from ._utils import (
     clone,
-    deallocate,
     detach,
     get_batch_size,
     get_micro_batch,
     merge_batch,
     model_forward,
+    release_tensor_data,
     require_grad,
     retain_grad,
     to_device,
@@ -488,8 +488,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
             input_obj_, _ = tree_flatten(input_obj)
-            output_obj_, _ = tree_flatten(output_obj)  # LOSS
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # None
+            output_obj_.append(output_obj)  # LOSS
+            output_obj_grad_.append(output_obj_grad)  # None

         # For other chunk stage, use input_obj as input_obj_;
         else:
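A note on why the two appended lines fix the loss branch: on recent PyTorch versions NoneType is registered as an empty pytree node, so tree_flatten(output_obj_grad) with output_obj_grad = None produces no leaves and the gradient list falls out of step with the output list. Appending the objects directly keeps the lists aligned and preserves the None placeholder that lets autograd start backward from the scalar loss with an implicit gradient of 1. A standalone sketch of that behavior (variable names mirror the diff but are local to the example):

import torch
from torch.utils._pytree import tree_flatten

x = torch.tensor(2.0, requires_grad=True)
loss = (x * 3).sum()  # stand-in for the stage's scalar LOSS output

# On versions where NoneType is an empty pytree node, flattening the None
# gradient yields no leaves, so the grad list would be shorter than the output list.
leaves, _ = tree_flatten(None)
print(leaves)  # [] on such versions

# Appending directly keeps outputs and grads aligned one-to-one.
output_obj_, output_obj_grad_ = [], []
output_obj_.append(loss)        # LOSS
output_obj_grad_.append(None)   # None grad -> implicit gradient of 1 for a scalar loss

torch.autograd.backward(tensors=output_obj_, grad_tensors=output_obj_grad_)
print(x.grad)  # tensor(3.)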
@@ -614,20 +614,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )

-        # Step3: deallocate output for bwd b & w; (do not detach output)
+        # Step3: release_tensor_data output for bwd b & w; (do not detach output)
         deallocate_output_obj = tree_map(clone, output_obj)
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # We should not deallocate bwd LOSS
+            # We should not release_tensor_data bwd LOSS
             pass
         else:
-            # deallocate output
-            tree_map(deallocate, deallocate_output_obj)
+            # release_tensor_data output
+            tree_map(release_tensor_data, deallocate_output_obj)

         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append((micro_batch, input_obj))

         # for bwd b&w, we only need the graph(grad_fn) of output_obj
-        # Do not deallocate loss, deallocate other output_obj;
+        # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             self.output_tensors[model_chunk_id].append(deallocate_output_obj)
             self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
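The Step3 logic relies on a property of autograd: once the stage output has been sent downstream, backward b and w only need the graph reachable through output_obj.grad_fn, not the output's data, so its storage can be released (the loss is exempt because backward starts from its value). A self-contained sketch of the idea, assuming recent PyTorch and an output op whose backward does not save its own result:

import torch

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
out = x @ w  # matmul's backward saves x and w, not out, so out's data can be dropped

# Pretend `out` was already sent to the next stage; free its storage while
# keeping shape metadata and grad_fn so backward can still be driven through it.
out.untyped_storage().resize_(0)

# The gradient arriving back from the next stage is enough to run backward.
grad_from_next_stage = torch.ones(4, 4)
torch.autograd.backward(out, grad_tensors=grad_from_next_stage)
print(x.grad.shape, w.grad.shape)  # torch.Size([4, 4]) torch.Size([4, 4])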