[fix] fix bwd step if condition; remove useless comments and format info;

pull/6034/head
duanjunwen 2024-09-03 08:56:08 +00:00
parent ab643c9af7
commit 4c1f81c683
4 changed files with 54 additions and 1188 deletions

View File

@@ -55,9 +55,6 @@ class OptimizerWrapper:
         """
         loss.backward(*args, **kwargs)
 
-    # def backward_by_grad(self, tensor: Tensor, grad: Tensor):
-    #     torch.autograd.backward(tensor, grad)
-
     def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Performs a backward pass for dx or dw,
@@ -78,26 +75,6 @@ class OptimizerWrapper:
             retain_graph=retain_graph,
         )
 
-    # def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
-    #     """
-    #     Performs a backward pass for dx or dw,
-    #     for dx, we only calculate dx = w*dy here
-    #     for dw, we only calculate dw = x*dy here
-    #     Args:
-    #         tensor (Tensor): y or loss of current chunk;
-    #         grad_tensors (Tensor): dy of current chunk;
-    #         input_obj (Tensor): for dx, input_obj is x of current chunk;
-    #                             for dw, input_obj is w of current chunk;
-    #         retain_graph (bool): default to be True, we retain graph in backward_b
-    #     """
-    #     torch.autograd.backward(
-    #         tensors=tensors,
-    #         grad_tensors=grad_tensors,
-    #         inputs=inputs,
-    #         retain_graph=retain_graph,
-    #     )
-
     def state_dict(self):
         """
         Returns the optimizer state.
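Note: judging from the surviving context lines, the kept backward_by_grad is a thin wrapper over torch.autograd.backward. The sketch below is illustrative only (not the file's exact body); it shows how the two half-backwards differ solely in what is passed as inputs and retain_graph.

import torch
from torch import Tensor

def backward_by_grad(tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
    # dx pass: inputs = the chunk's input activation, retain_graph=True
    # dw pass: inputs = list(chunk.parameters()),     retain_graph=False
    torch.autograd.backward(
        tensors=tensor,
        grad_tensors=grad,          # None when `tensor` is the scalar loss
        inputs=inputs,              # restrict gradient accumulation to these tensors
        retain_graph=retain_graph,
    )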

View File

@@ -33,14 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     only useful for its '.grad_fn' field, and not its '.data'.
     """
     if (out is None) or (not deallocate_pipeline_outputs):
-        print(
-            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
-        )
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
     assert out._base is None, "counter-productive to free a view of another tensor."
     # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.storage().resize_(0)
+    out.data.untyped_storage().resize_(0)
 
 
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
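The switch from storage() to untyped_storage() keeps the same memory trick working on newer PyTorch. A minimal standalone sketch of that trick (assumes PyTorch >= 2.0, where Tensor.untyped_storage() exists; this is not the scheduler code itself):

import torch

x = torch.randn(4, 4, requires_grad=True)
out = x * 2                                   # output whose graph is still needed for backward
keep = out.clone()                            # handle the scheduler stores for the later b/w passes
keep.data.untyped_storage().resize_(0)        # free the buffer
print(keep.shape)                             # torch.Size([4, 4]) -- shape metadata survives
print(keep.untyped_storage().size())          # 0                  -- storage is gone
print(keep.grad_fn)                           # CloneBackward0     -- graph stays reachable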
@@ -457,33 +454,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Retain the grad on the input_obj.
         tree_map(retain_grad, input_obj)
 
-        if model_chunk_id == 0:
-            # bwd step
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=input_obj,
-                retain_graph=True,
-            )
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # loss backward; output_obj is loss
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
-            else:
-                # common bwd step
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=input_obj,
+            retain_graph=True,
+        )
         return input_obj.grad
 
     def backward_w_step(
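With the fixed condition, the only special case left is the chunk that produced the loss (chunk 1 on the first stage); everything else takes the common path. A simplified, hypothetical rendering of the backward-b (dx) step, not the class method itself:

import torch
from typing import Optional

def backward_b_step_sketch(output_obj: torch.Tensor,
                           output_obj_grad: Optional[torch.Tensor],
                           input_obj: torch.Tensor,
                           is_loss_chunk: bool) -> torch.Tensor:
    # is_loss_chunk stands for: model_chunk_id == 1 and stage_manager.is_first_stage(ignore_chunk=True)
    # caller is expected to have retained grads on input_obj (the scheduler does this via tree_map(retain_grad, ...))
    if is_loss_chunk:
        output_obj_grad = None       # output_obj is the scalar loss, so no upstream grad is passed
    torch.autograd.backward(
        tensors=output_obj,
        grad_tensors=output_obj_grad,
        inputs=input_obj,            # accumulate only dx = dL/d(input) here
        retain_graph=True,           # the graph is reused by the backward-w step
    )
    return input_obj.grad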
@@ -507,29 +486,39 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Nothing need to return; we only calculate dw then update w;
         """
         # calculate bwd w step ; only dw = x*dy;
-        if model_chunk_id == 0:
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=list(model_chunk[model_chunk_id].parameters()),
-                retain_graph=False,
-            )
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
-            else:
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=list(model_chunk[model_chunk_id].parameters()),
+            retain_graph=False,
+        )
+
+        # if model_chunk_id == 0:
+        #     optimizer.backward_by_grad(
+        #         tensor=output_obj,
+        #         grad=output_obj_grad,
+        #         inputs=list(model_chunk[model_chunk_id].parameters()),
+        #         retain_graph=False,
+        #     )
+        # else:
+        #     if self.stage_manager.is_first_stage(ignore_chunk=True):
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=None,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )
+        #     else:
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=output_obj_grad,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )
 
     def schedule_f(
         self,
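The backward-w (dw) step mirrors backward-b: the same loss-chunk special case, but gradients are accumulated into the chunk's parameters and the graph is released afterwards. Again a hypothetical, simplified sketch rather than the class method:

import torch
from typing import Iterable, Optional

def backward_w_step_sketch(output_obj: torch.Tensor,
                           output_obj_grad: Optional[torch.Tensor],
                           chunk_params: Iterable[torch.nn.Parameter],
                           is_loss_chunk: bool) -> None:
    if is_loss_chunk:
        output_obj_grad = None       # same special case as in backward-b
    torch.autograd.backward(
        tensors=output_obj,
        grad_tensors=output_obj_grad,
        inputs=list(chunk_params),   # accumulate only dw = x * dy into the parameters
        retain_graph=False,          # last use of this graph; let autograd free it
    )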
@@ -578,15 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        detached_output_obj = output_obj.clone()
-        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(detached_output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
-
 
         # Step3: send fwd
         # add output to send_fwd_buffer
@@ -603,6 +583,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             self.send_forward_buffer[model_chunk_id].append(output_obj)
 
+        # add input and output object for backward b
+        self.input_tensors[model_chunk_id].append(input_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
+
     def schedule_b(
         self,
         scheduled_node,

File diff suppressed because it is too large.

View File

@@ -50,7 +50,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 4,
+            "num_model_chunk": 2,
         },
     ],
 )
@@ -507,7 +507,7 @@ def run_fwd_bwd_iter_input(test_config):
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 4,
+            "num_model_chunk": 2,
         },
     ],
 )
@@ -702,8 +702,7 @@ def run_with_hybridplugin(test_config):
 def run_with_moehybridplugin(test_config):
     model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
-    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
-    test_config["initial_scale"] = 2**16  # avoid overflow
+    test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
     ]
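For reference, a sketch of one updated parametrize entry after this change, built only from the keys visible in the diff above (the real test file may carry additional keys that are omitted here):

test_config = {
    "num_microbatches": 4,
    "zero_stage": 1,
    "precision": "bf16",
    "num_model_chunk": 2,   # was 4; the V-shaped schedule pairs chunk 0 and chunk 1 per stage, matching the model_chunk_id == 1 checks above
}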