[fix] fix input_tensors buffer append input_obj(dict) --> Tuple (microbatch, input_obj) , and all bwd b related cal logic;

pull/6065/head
duanjunwen 2024-09-20 06:41:19 +00:00
parent 4753bf7add
commit 26783776f1
1 changed files with 32 additions and 28 deletions

View File

@ -458,6 +458,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module], model_chunk: Union[ModuleList, Module],
model_chunk_id: int, model_chunk_id: int,
optimizer: OptimizerWrapper, optimizer: OptimizerWrapper,
micro_batch: Optional[dict],
input_obj: Optional[dict], input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor], output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict], output_obj_grad: Optional[dict],
@ -468,7 +469,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx; model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): x. input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
output_obj (Union[dict, torch.Tensor]): y. output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy. output_obj_grad (dict): dy.
@ -477,10 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
""" """
# calculate bwd b step ; only dx = w*dy; # calculate bwd b step ; only dx = w*dy;
# Retain the grad on the input_obj. # Retain the grad on the input_obj. No need retain_grad microbatch
if input_obj is None: if input_obj is not None:
return None
else:
tree_map(retain_grad, input_obj) tree_map(retain_grad, input_obj)
# x, y, dy list for backward_by_grad; Type: list[tensor]; # x, y, dy list for backward_by_grad; Type: list[tensor];
@ -488,22 +487,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_ = [] output_obj_ = []
output_obj_grad_ = [] output_obj_grad_ = []
# get x from input_obj to input_obj_ # For chunk 0 stage 0, use micro_batch as input_obj_
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if v.requires_grad:
input_obj_.append(micro_batch[k])
output_obj_.append(output_obj[k]) # y
output_obj_grad_.append(output_obj_grad[k]) # dy
# For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
for k, v in input_obj.items(): for k, v in input_obj.items():
if v.requires_grad: if v.requires_grad:
input_obj_.append(input_obj[k]) input_obj_.append(input_obj[k])
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
output_obj_grad_.append(output_obj_grad) # None
output_obj_.append(output_obj) # LOSS output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else: else:
for k, v in input_obj.items(): for k, v in input_obj.items():
if v.requires_grad: if v.requires_grad:
output_obj_.append(output_obj[k]) input_obj_.append(input_obj[k])
output_obj_grad_.append(output_obj_grad[k]) output_obj_.append(output_obj[k]) # y
output_obj_grad_.append(output_obj_grad[k]) # dy
optimizer.backward_by_grad( optimizer.backward_by_grad(
tensor=output_obj_, tensor=output_obj_,
@ -512,9 +517,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
retain_graph=True, retain_graph=True,
) )
# format output_obj_grad # Format output_obj_grad
if input_obj is not None:
input_obj_grad = {} input_obj_grad = {}
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
else:
for k, v in input_obj.items(): for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None: if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad input_obj_grad[k] = v.grad
@ -551,10 +560,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_.append(output_obj) # LOSS output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None output_obj_grad_.append(None) # None
else: else:
# for k, v in input_obj.items():
# if v.requires_grad:
# output_obj_.append(output_obj[k])
# output_obj_grad_.append(output_obj_grad[k])
for k, v in output_obj.items(): for k, v in output_obj.items():
if v.requires_grad: if v.requires_grad:
output_obj_.append(output_obj[k]) output_obj_.append(output_obj[k])
@ -634,10 +639,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
tree_map(deallocate, deallocate_output_obj) tree_map(deallocate, deallocate_output_obj)
# add input and output object for backward b # add input and output object for backward b
if input_obj is not None:
self.input_tensors[model_chunk_id].append(input_obj) self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
else:
self.input_tensors[model_chunk_id].append(micro_batch)
# for bwd b&w, we only need the graph(grad_fn) of output_obj # for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not deallocate loss, deallocate other output_obj; # Do not deallocate loss, deallocate other output_obj;
@ -703,7 +706,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer; # get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0) micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw # save output_tensor_grad for dw
@ -719,6 +722,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk=model_chunk, model_chunk=model_chunk,
model_chunk_id=model_chunk_id, model_chunk_id=model_chunk_id,
optimizer=optimizer, optimizer=optimizer,
micro_batch=micro_batch,
input_obj=input_obj, input_obj=input_obj,
output_obj=output_obj, output_obj=output_obj,
output_obj_grad=output_tensor_grad, output_obj_grad=output_tensor_grad,