[fix] fix bwd step if condition; remove useless comments and format info;

pull/6034/head
duanjunwen 2024-09-03 08:56:08 +00:00
parent ab643c9af7
commit 4c1f81c683
4 changed files with 54 additions and 1188 deletions

View File

@ -55,9 +55,6 @@ class OptimizerWrapper:
"""
loss.backward(*args, **kwargs)
# def backward_by_grad(self, tensor: Tensor, grad: Tensor):
# torch.autograd.backward(tensor, grad)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
@ -78,26 +75,6 @@ class OptimizerWrapper:
retain_graph=retain_graph,
)
# def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
# """
# Performs a backward pass for dx or dw,
# for dx, we only calculate dx = w*dy here
# for dw, we only calculate dw = x*dy here
# Args:
# tensor (Tensor): y or loss of current chunk;
# grad_tensors (Tensor): dy of current chunk;
# input_obj (Tensor): for dx, input_obj is x of current chunk;
# for dw, input_obj is w of current chunk;
# retain_graph (bool): default to be True, we retain graph in backward_b
# """
# torch.autograd.backward(
# tensors=tensors,
# grad_tensors=grad_tensors,
# inputs=inputs,
# retain_graph=retain_graph,
# )
def state_dict(self):
"""
Returns the optimizer state.

View File

@ -33,14 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
only useful for its '.grad_fn' field, and not its '.data'.
"""
if (out is None) or (not deallocate_pipeline_outputs):
print(
f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
)
return
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
# out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
out.data.storage().resize_(0)
out.data.untyped_storage().resize_(0)
class ZeroBubbleVPipeScheduler(PipelineSchedule):
@ -457,33 +454,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
if model_chunk_id == 0:
# bwd step
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=input_obj,
retain_graph=True,
)
else:
# commom bwd step
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
return input_obj.grad
def backward_w_step(
@ -507,29 +486,39 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Nothing need to return; we only calculate dw then update w;
"""
# calculate bwd w step ; only dw = x*dy;
if model_chunk_id == 0:
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
# if model_chunk_id == 0:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# if self.stage_manager.is_first_stage(ignore_chunk=True):
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=None,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
def schedule_f(
self,
@ -578,15 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
detached_output_obj = output_obj.clone()
deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(detached_output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
# Step3: send fwd
# add output to send_fwd_buffer
@ -603,6 +583,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
self.send_forward_buffer[model_chunk_id].append(output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
detached_output_obj = output_obj.clone()
deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(detached_output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
def schedule_b(
self,
scheduled_node,

File diff suppressed because it is too large Load Diff

View File

@ -50,7 +50,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 4,
"num_model_chunk": 2,
},
],
)
@ -507,7 +507,7 @@ def run_fwd_bwd_iter_input(test_config):
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 4,
"num_model_chunk": 2,
},
],
)
@ -702,8 +702,7 @@ def run_with_hybridplugin(test_config):
def run_with_moehybridplugin(test_config):
model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**16 # avoid overflow
test_config["initial_scale"] = 2**16
model_list = [
"transformers_bert",
]