mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix input_tensors buffer to append Tuple (micro_batch, input_obj) instead of input_obj (dict), and update all related bwd b calculation logic;
parent 4753bf7add
commit 26783776f1
@@ -458,6 +458,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -468,7 +469,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk (ModuleList or Module): Model Chunk to be run;
             model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): x.
+            input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
             output_obj (Union[dict, torch.Tensor]): y.
             output_obj_grad (dict): dy.
 
@@ -477,10 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd b step ; only dx = w*dy;
 
-        # Retain the grad on the input_obj.
-        if input_obj is None:
-            return None
-        else:
-            tree_map(retain_grad, input_obj)
+        # Retain the grad on the input_obj. No need retain_grad microbatch
+        if input_obj is not None:
+            tree_map(retain_grad, input_obj)
 
         # x, y, dy list for backward_by_grad; Type: list[tensor];
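The reworked retain step only touches input_obj (the activation dict received from the previous stage) and deliberately skips micro_batch, whose tensors are graph leaves and accumulate .grad on their own. Below is a minimal sketch of what tree_map(retain_grad, input_obj) does, assuming tree_map comes from torch.utils._pytree and retain_grad is a small helper along these lines (the helper body is an assumption, not the repository's exact code):

import torch
from torch.utils._pytree import tree_map

def retain_grad(x):
    # Assumed helper: keep .grad on non-leaf tensors so backward_b can read dx from them later.
    if isinstance(x, torch.Tensor) and x.requires_grad:
        x.retain_grad()

# input_obj holds activations from the previous stage, i.e. non-leaf tensors.
input_obj = {"hidden_states": torch.randn(2, 4, requires_grad=True) * 2.0}
tree_map(retain_grad, input_obj)
input_obj["hidden_states"].sum().backward()
print(input_obj["hidden_states"].grad)  # populated only because retain_grad() was called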
@@ -488,22 +487,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = []
         output_obj_grad_ = []
 
-        # get x from input_obj to input_obj_
-        for k, v in input_obj.items():
-            if v.requires_grad:
-                input_obj_.append(input_obj[k])
-
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # loss backward; output_obj is loss; so output_obj_grad should be None
-            assert output_obj_grad is None
-            output_obj_grad_.append(output_obj_grad)  # None
-            output_obj_.append(output_obj)  # LOSS
-        else:
-            for k, v in input_obj.items():
-                if v.requires_grad:
-                    output_obj_.append(output_obj[k])
-                    output_obj_grad_.append(output_obj_grad[k])
+        # For chunk 0 stage 0, use micro_batch as input_obj_
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            for k, v in micro_batch.items():
+                if v.requires_grad:
+                    input_obj_.append(micro_batch[k])
+                    output_obj_.append(output_obj[k])  # y
+                    output_obj_grad_.append(output_obj_grad[k])  # dy
+        # For loss backward; output_obj is loss; output_obj_grad should be None
+        elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            assert output_obj_grad is None
+            for k, v in input_obj.items():
+                if v.requires_grad:
+                    input_obj_.append(input_obj[k])
+            output_obj_.append(output_obj)  # LOSS
+            output_obj_grad_.append(output_obj_grad)  # None
+        # For other chunk stage, use input_obj as input_obj_;
+        else:
+            for k, v in input_obj.items():
+                if v.requires_grad:
+                    input_obj_.append(input_obj[k])
+                    output_obj_.append(output_obj[k])  # y
+                    output_obj_grad_.append(output_obj_grad[k])  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
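The branches above only decide which dict supplies x (micro_batch on chunk 0 / stage 0, otherwise input_obj); after that, backward_b always reduces to one call over the flat x / y / dy lists. A rough autograd-level sketch of that call, with torch.autograd.backward standing in for OptimizerWrapper.backward_by_grad (whose exact keyword names are not shown in this hunk and are assumed here):

import torch

def backward_b_sketch(input_obj_, output_obj_, output_obj_grad_):
    # Compute only dx = dL/dy * dy/dx with respect to the stage inputs; dw is left for backward_w.
    torch.autograd.backward(
        tensors=output_obj_,            # y, or the loss on the last chunk's first stage
        grad_tensors=output_obj_grad_,  # dy from the next stage; None when y is the scalar loss
        inputs=input_obj_,              # restrict accumulation to x so only dx is produced here
        retain_graph=True,              # the graph is reused by the separate dw (backward_w) step
    )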
@@ -512,9 +517,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             retain_graph=True,
         )
 
-        # format output_obj_grad
-        if input_obj is not None:
-            input_obj_grad = {}
-            for k, v in input_obj.items():
-                if isinstance(v, torch.Tensor) and v.grad is not None:
-                    input_obj_grad[k] = v.grad
+        # Format output_obj_grad
+        input_obj_grad = {}
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            for k, v in micro_batch.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
+        else:
+            for k, v in input_obj.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
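Once the backward call has run, dx is read back from .grad of whichever dict served as x; the two loops above differ only in that source. A compact, hedged equivalent of the collection step (illustrative only, not the repository's code; use_micro_batch is a hypothetical flag standing for the chunk-0 / first-stage check):

import torch

def collect_input_grads(source: dict) -> dict:
    # Gather dx per key; non-tensor values and tensors that received no gradient are skipped.
    return {k: v.grad for k, v in source.items() if isinstance(v, torch.Tensor) and v.grad is not None}

# chunk 0 / first stage reads grads from the raw micro_batch, every other position from input_obj:
# input_obj_grad = collect_input_grads(micro_batch if use_micro_batch else input_obj)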
@@ -551,10 +560,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            # for k, v in input_obj.items():
-            #     if v.requires_grad:
-            #         output_obj_.append(output_obj[k])
-            #         output_obj_grad_.append(output_obj_grad[k])
             for k, v in output_obj.items():
                 if v.requires_grad:
                     output_obj_.append(output_obj[k])
@@ -634,10 +639,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             tree_map(deallocate, deallocate_output_obj)
 
         # add input and output object for backward b
-        if input_obj is not None:
-            self.input_tensors[model_chunk_id].append(input_obj)
-        else:
-            self.input_tensors[model_chunk_id].append(micro_batch)
+        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not deallocate loss, deallocate other output_obj;
@@ -703,7 +706,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output object from buffer;
-        input_obj = self.input_tensors[model_chunk_id].pop(0)
+        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
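The earlier append hunk and the pop in this hunk are the producer and consumer of the same buffer entry: the forward schedule now always appends the pair (micro_batch, input_obj), and the backward schedule unpacks it with a single pop, so backward_b_step no longer has to guess which dict it was given. A minimal sketch of that symmetry, with the scheduler state simplified to a plain dict of per-chunk lists (a simplification, not the class's real layout):

from collections import defaultdict

input_tensors = defaultdict(list)  # stand-in for self.input_tensors: one FIFO per model chunk

def schedule_f_sketch(model_chunk_id, micro_batch, input_obj):
    # Forward side: buffer the pair unconditionally, even when input_obj is None (chunk 0, stage 0).
    input_tensors[model_chunk_id].append((micro_batch, input_obj))

def schedule_b_sketch(model_chunk_id):
    # Backward side: one pop yields both pieces that backward_b_step now expects.
    micro_batch, input_obj = input_tensors[model_chunk_id].pop(0)
    return micro_batch, input_obj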
@@ -719,6 +722,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,