[fix] fix input_tensors buffer append input_obj(dict) --> Tuple (microbatch, input_obj) , and all bwd b related cal logic;

pull/6065/head
duanjunwen 2024-09-20 06:41:19 +00:00
parent 4753bf7add
commit 26783776f1
1 changed files with 32 additions and 28 deletions

View File

@ -458,6 +458,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
micro_batch: Optional[dict],
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
@ -468,7 +469,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): x.
input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
@ -477,10 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# calculate bwd b step ; only dx = w*dy;
# Retain the grad on the input_obj.
if input_obj is None:
return None
else:
# Retain the grad on the input_obj. No need retain_grad microbatch
if input_obj is not None:
tree_map(retain_grad, input_obj)
# x, y, dy list for backward_by_grad; Type: list[tensor];
@ -488,22 +487,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_ = []
output_obj_grad_ = []
# get x from input_obj to input_obj_
# For chunk 0 stage 0, use micro_batch as input_obj_
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if v.requires_grad:
input_obj_.append(micro_batch[k])
output_obj_.append(output_obj[k]) # y
output_obj_grad_.append(output_obj_grad[k]) # dy
# For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
for k, v in input_obj.items():
if v.requires_grad:
input_obj_.append(input_obj[k])
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
output_obj_grad_.append(output_obj_grad) # None
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else:
for k, v in input_obj.items():
if v.requires_grad:
output_obj_.append(output_obj[k])
output_obj_grad_.append(output_obj_grad[k])
input_obj_.append(input_obj[k])
output_obj_.append(output_obj[k]) # y
output_obj_grad_.append(output_obj_grad[k]) # dy
optimizer.backward_by_grad(
tensor=output_obj_,
@ -512,9 +517,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
retain_graph=True,
)
# format output_obj_grad
if input_obj is not None:
# Format output_obj_grad
input_obj_grad = {}
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
else:
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
@ -551,10 +560,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None
else:
# for k, v in input_obj.items():
# if v.requires_grad:
# output_obj_.append(output_obj[k])
# output_obj_grad_.append(output_obj_grad[k])
for k, v in output_obj.items():
if v.requires_grad:
output_obj_.append(output_obj[k])
@ -634,10 +639,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
tree_map(deallocate, deallocate_output_obj)
# add input and output object for backward b
if input_obj is not None:
self.input_tensors[model_chunk_id].append(input_obj)
else:
self.input_tensors[model_chunk_id].append(micro_batch)
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not deallocate loss, deallocate other output_obj;
@ -703,7 +706,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
@ -719,6 +722,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
micro_batch=micro_batch,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,