mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix bwd step if condition; remove useless comments and format info;
parent ab643c9af7
commit 4c1f81c683
@@ -55,9 +55,6 @@ class OptimizerWrapper:
         """
         loss.backward(*args, **kwargs)
 
-    # def backward_by_grad(self, tensor: Tensor, grad: Tensor):
-    #     torch.autograd.backward(tensor, grad)
-
     def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Performs a backward pass for dx or dw,
@@ -78,26 +75,6 @@ class OptimizerWrapper:
             retain_graph=retain_graph,
         )
 
-    # def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
-    #     """
-    #     Performs a backward pass for dx or dw,
-    #     for dx, we only calculate dx = w*dy here
-    #     for dw, we only calculate dw = x*dy here
-
-    #     Args:
-    #         tensor (Tensor): y or loss of current chunk;
-    #         grad_tensors (Tensor): dy of current chunk;
-    #         input_obj (Tensor): for dx, input_obj is x of current chunk;
-    #                             for dw, input_obj is w of current chunk;
-    #         retain_graph (bool): default to be True, we retain graph in backward_b
-    #     """
-    #     torch.autograd.backward(
-    #         tensors=tensors,
-    #         grad_tensors=grad_tensors,
-    #         inputs=inputs,
-    #         retain_graph=retain_graph,
-    #     )
-
     def state_dict(self):
         """
         Returns the optimizer state.
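Note: judging from the commented-out helper deleted above, backward_by_grad is a thin wrapper over torch.autograd.backward. A minimal sketch under that assumption (illustrative class, not the repository's implementation), showing why the inputs argument matters for splitting backward into separate dx and dw passes:

import torch
from torch import Tensor


class OptimizerWrapperSketch:
    """Illustrative only; mirrors the call pattern in the comments deleted above."""

    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
        # Restricting `inputs` makes autograd accumulate gradients only into those
        # leaves, so the scheduler can run a dx-only pass (inputs = stage input)
        # and a dw-only pass (inputs = parameters) over the same retained graph.
        torch.autograd.backward(
            tensors=tensor,
            grad_tensors=grad,
            inputs=inputs,
            retain_graph=retain_graph,
        )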
@@ -33,14 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     only useful for its '.grad_fn' field, and not its '.data'.
     """
     if (out is None) or (not deallocate_pipeline_outputs):
-        print(
-            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
-        )
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
     assert out._base is None, "counter-productive to free a view of another tensor."
     # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.storage().resize_(0)
+    out.data.untyped_storage().resize_(0)
 
 
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
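Note: besides dropping the debug print, the functional change here is the storage accessor. Tensor.storage() returns the legacy TypedStorage wrapper, which recent PyTorch flags as deprecated, while untyped_storage() is the current way to reach the raw bytes; presumably that is why the call is switched. Resizing that storage to zero frees the memory behind the tensor while leaving the Python object, its shape metadata, and its grad_fn intact. A standalone check (toy tensor, not ColossalAI code):

import torch

t = torch.randn(1024, 1024)
print(t.untyped_storage().size())     # bytes currently backing the tensor
t.data.untyped_storage().resize_(0)   # release the memory, as deallocate_output_tensor does
print(t.untyped_storage().size())     # 0: data is gone, the tensor object remains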
@@ -457,33 +454,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Retain the grad on the input_obj.
         tree_map(retain_grad, input_obj)
 
-        if model_chunk_id == 0:
-            # bwd step
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=input_obj,
-                retain_graph=True,
-            )
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # loss backward; output_obj is loss
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
-
-            else:
-                # commom bwd step
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
-
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=input_obj,
+            retain_graph=True,
+        )
         return input_obj.grad
 
     def backward_w_step(
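Note: in backward_b_step (above) and backward_w_step (next hunk) the three-way branch collapses into a single guarded assignment followed by one backward_by_grad call. Only model chunk 1 on the first stage holds the loss, so only there is the incoming gradient replaced by None; every other chunk/stage passes the dy received from the next stage. A self-contained illustration of those two grad cases (toy, example-local names):

import torch

x = torch.randn(4, requires_grad=True)
y = x * 3              # an intermediate chunk output: dy arrives from the next stage
loss = y.sum()         # the loss produced by chunk 1 on the first stage

# Loss case: grad is None, so autograd supplies the implicit scalar 1.0.
torch.autograd.backward(loss, grad_tensors=None, inputs=[x], retain_graph=True)
# Common case: the received dy is passed explicitly.
torch.autograd.backward(y, grad_tensors=torch.ones(4), inputs=[x], retain_graph=True)
print(x.grad)          # contributions from both calls accumulate here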
@@ -507,29 +486,39 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Nothing need to return; we only calculate dw then update w;
         """
         # calculate bwd w step ; only dw = x*dy;
-        if model_chunk_id == 0:
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=list(model_chunk[model_chunk_id].parameters()),
-                retain_graph=False,
-            )
-
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
-            else:
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=list(model_chunk[model_chunk_id].parameters()),
+            retain_graph=False,
+        )
+
+        # if model_chunk_id == 0:
+        #     optimizer.backward_by_grad(
+        #         tensor=output_obj,
+        #         grad=output_obj_grad,
+        #         inputs=list(model_chunk[model_chunk_id].parameters()),
+        #         retain_graph=False,
+        #     )
+
+        # else:
+        #     if self.stage_manager.is_first_stage(ignore_chunk=True):
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=None,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )
+        #     else:
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=output_obj_grad,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )
 
     def schedule_f(
         self,
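Note: backward_w_step mirrors the change above but targets the chunk's parameters instead of its input, matching the comment "only dw = x*dy". A toy end-to-end sketch of the dx-then-dw split these two methods implement (example-local tensors; W stands in for model_chunk[model_chunk_id].parameters()):

import torch

W = torch.randn(3, 3, requires_grad=True)   # stand-in for the chunk's parameters
x = torch.randn(3, requires_grad=True)      # stand-in for the chunk's input
y = W @ x                                   # chunk output
dy = torch.ones(3)                          # gradient received from the next stage

# backward_b_step: dx only, graph retained so the weight pass can run later.
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
# backward_w_step: dw only, graph released afterwards (retain_graph=False).
torch.autograd.backward(y, grad_tensors=dy, inputs=[W], retain_graph=False)
print(x.grad.shape, W.grad.shape)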
@@ -578,15 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        detached_output_obj = output_obj.clone()
-        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(detached_output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
 
         # Step3: send fwd
         # add output to send_fwd_buffer
@@ -603,6 +583,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             self.send_forward_buffer[model_chunk_id].append(output_obj)
 
+        # add input and output object for backward b
+        self.input_tensors[model_chunk_id].append(input_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)
+
     def schedule_b(
         self,
         scheduled_node,
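Note: the bookkeeping block is moved below the send but is otherwise unchanged. The stage output itself goes into send_forward_buffer with its data intact, while the locally kept copy is a clone whose storage is immediately freed, relying on the property spelled out in the comment that only the clone's grad_fn is needed for backward b/w. A small standalone check of that property (toy graph, not the scheduler):

import torch

x = torch.randn(8, requires_grad=True)
out = x.exp()                                # stage output, kept whole for the p2p send
saved = out.clone()                          # what would be stored for backward b / w
saved.data.untyped_storage().resize_(0)      # free the clone's memory, keep its grad_fn

# Backward still flows through CloneBackward into the original graph.
torch.autograd.backward(saved, grad_tensors=torch.ones(8), inputs=[x])
print(x.grad is not None)                    # True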
(File diff suppressed because it is too large.)
@@ -50,7 +50,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 4,
+            "num_model_chunk": 2,
         },
     ],
 )
@@ -507,7 +507,7 @@ def run_fwd_bwd_iter_input(test_config):
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 4,
+            "num_model_chunk": 2,
         },
     ],
 )
@@ -702,8 +702,7 @@ def run_with_hybridplugin(test_config):
 def run_with_moehybridplugin(test_config):
     model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
-    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
-    test_config["initial_scale"] = 2**16  # avoid overflow
+    test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
     ]
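Note: the test updates above shrink num_model_chunk from 4 to 2, which lines up with a V-shaped schedule that places two model chunks on each pipeline stage. A hedged sketch of how such a parametrized config is typically consumed in pytest (decorator target and test body are illustrative, not the repository's test code):

import pytest


@pytest.mark.parametrize(
    "test_config",
    [
        {
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 2,
        },
    ],
)
def test_zerobubble_config(test_config):
    # Two chunks per stage is what a V-shaped (up-down) pipeline schedule expects.
    assert test_config["num_model_chunk"] == 2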