mirror of https://github.com/hpcaitech/ColossalAI
[fix] rm comments;
parent
c6d6ee39bd
commit
b6616f544e
|
@ -489,12 +489,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
# For chunk 0 stage 0, use micro_batch as input_obj_
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# for k, v in micro_batch.items():
|
||||
# if v.requires_grad:
|
||||
# input_obj_.append(micro_batch[k])
|
||||
# output_obj_.append(output_obj[k]) # y
|
||||
# output_obj_grad_.append(output_obj_grad[k]) # dy
|
||||
|
||||
input_obj_, _ = tree_flatten(micro_batch)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
@ -502,22 +496,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
assert output_obj_grad is None
|
||||
# for k, v in input_obj.items():
|
||||
# if v.requires_grad:
|
||||
# input_obj_.append(input_obj[k])
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
# output_obj_.append(output_obj) # LOSS
|
||||
# output_obj_grad_.append(output_obj_grad) # None
|
||||
output_obj_, _ = tree_flatten(output_obj) # LOSS
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # None
|
||||
|
||||
# For every other chunk and stage, use input_obj as input_obj_
|
||||
else:
|
||||
# for k, v in input_obj.items():
|
||||
# if v.requires_grad:
|
||||
# input_obj_.append(input_obj[k])
|
||||
# output_obj_.append(output_obj[k]) # y
|
||||
# output_obj_grad_.append(output_obj_grad[k]) # dy
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
@ -572,10 +556,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(None) # None
|
||||
else:
|
||||
# for k, v in output_obj.items():
|
||||
# if v.requires_grad:
|
||||
# output_obj_.append(output_obj[k])
|
||||
# output_obj_grad_.append(output_obj_grad[k])
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
|
@ -653,7 +633,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
tree_map(deallocate, deallocate_output_obj)
|
||||
|
||||
# Store the input and output objects for backward B
|
||||
|
||||
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
|
||||
|
||||
# For backward B and W, we only need the graph (grad_fn) of output_obj
|
||||
|
|
Loading…
Reference in New Issue