[fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

pull/6034/head
duanjunwen 2024-08-30 05:42:43 +00:00
parent 6af81d8c0d
commit 8eb6eac225
3 changed files with 63 additions and 68 deletions

View File

@ -58,14 +58,17 @@ class OptimizerWrapper:
def backward_by_grad(self, tensor: Tensor, grad: Tensor): def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensor, grad) torch.autograd.backward(tensor, grad)
def backward_b_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True): def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
""" """
Performs a backward pass for dx, we only calculate dx = w*dy here Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here
for dw, we only calculate dw = x*dy here
Args: Args:
tensor (Tensor): y or loss of current chunk; tensor (Tensor): y or loss of current chunk;
grad_tensors (Tensor): dy of current chunk; grad_tensors (Tensor): dy of current chunk;
input_obj (Tensor): x of current chunk; input_obj (Tensor): for dx, input_obj is x of current chunk;
for dw, input_obj is w of current chunk;
retain_graph (bool): default to be True, we retain graph in backward_b retain_graph (bool): default to be True, we retain graph in backward_b
""" """
torch.autograd.backward( torch.autograd.backward(
@ -75,23 +78,6 @@ class OptimizerWrapper:
retain_graph=retain_graph, retain_graph=retain_graph,
) )
def backward_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = False):
"""
Performs a backward pass for dw, we only calculate dw = x*dy here
Args:
tensor (Tensor): y or loss of current chunk;
grad_tensors (Tensor): dy of current chunk;
input_obj (Tensor): w;
retain_graph (bool): default to be False, we release graph in backward_w
"""
torch.autograd.backward(
tensors=tensors,
grad_tensors=grad_tensors,
inputs=inputs,
retain_graph=retain_graph,
)
def state_dict(self): def state_dict(self):
""" """
Returns the optimizer state. Returns the optimizer state.

View File

@ -1,6 +1,32 @@
# Refer from Zero Bubble Pipeline Parallelism. # Refer from Zero Bubble Pipeline Parallelism.
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism # Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
# Paper: https://arxiv.org/abs/2401.10241 # Paper: https://arxiv.org/abs/2401.10241
# The following applies to all files unless otherwise noted:
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -46,13 +46,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.last_batch_size: Optional[int] = None self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int] self.microbatch_offset: List[int]
self.collect_non_loss_data = None
self.forward_only = None
self.schedules = schedule self.schedules = schedule
# TODO: optim post valid # TODO: optim post valid
self.do_post_validation = False self.do_post_validation = False
# self.is_first_run = True
# self.optimizer = None
# P2PMeta cache # P2PMeta cache
# self.enable_metadata_cache = enable_metadata_cache # self.enable_metadata_cache = enable_metadata_cache
@ -166,6 +162,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1 model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id return model_chunk_id
def communication_func_map(self, node_type: str):
return {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}[node_type]
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV. For ZBV.
@ -439,10 +443,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0: if model_chunk_id == 0:
# bwd step # bwd step
# torch.autograd.backward( optimizer.backward_b_w_by_grad(
# tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
# )
optimizer.backward_b_by_grad(
tensors=output_obj, tensors=output_obj,
grad_tensors=output_obj_grad, grad_tensors=output_obj_grad,
inputs=input_obj, inputs=input_obj,
@ -451,8 +452,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss # loss backward; output_obj is loss
# torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True) optimizer.backward_b_w_by_grad(
optimizer.backward_b_by_grad(
tensors=output_obj, tensors=output_obj,
grad_tensors=None, grad_tensors=None,
inputs=input_obj, inputs=input_obj,
@ -461,10 +461,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
# commom bwd step # commom bwd step
# torch.autograd.backward( optimizer.backward_b_w_by_grad(
# tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
# )
optimizer.backward_b_by_grad(
tensors=output_obj, tensors=output_obj,
grad_tensors=output_obj_grad, grad_tensors=output_obj_grad,
inputs=input_obj, inputs=input_obj,
@ -495,30 +492,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
""" """
# calculate bwd w step ; only dw = x*dy; # calculate bwd w step ; only dw = x*dy;
if model_chunk_id == 0: if model_chunk_id == 0:
# torch.autograd.backward( optimizer.backward_b_w_by_grad(
# tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()) tensors=output_obj,
# ) grad_tensors=output_obj_grad,
optimizer.backward_w_by_grad( inputs=list(model_chunk[model_chunk_id].parameters()),
tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()) retain_graph=False,
) )
else: else:
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
# torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())) optimizer.backward_b_w_by_grad(
optimizer.backward_w_by_grad( tensors=output_obj,
tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()) grad_tensors=None,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
) )
else: else:
# torch.autograd.backward( optimizer.backward_b_w_by_grad(
# tensors=output_obj,
# grad_tensors=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# )
optimizer.backward_w_by_grad(
tensors=output_obj, tensors=output_obj,
grad_tensors=output_obj_grad, grad_tensors=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()), inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
) )
def schedule_f( def schedule_f(
@ -718,17 +712,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
it = 0
# while we still have schedules_node in self.schedules # while we still have schedules_node in self.schedules
while it < len(self.schedules): for it in range(len(self.schedules)):
scheduled_node = self.schedules[it] scheduled_node = self.schedules[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication # communication
if scheduled_node.type == "RECV_FORWARD": communication_func = self.communication_func_map(scheduled_node.type)
self.recv_forward(scheduled_node.chunk) communication_func(scheduled_node.chunk)
elif scheduled_node.type == "SEND_FORWARD":
self.send_forward(scheduled_node.chunk)
if scheduled_node.type == "F": if scheduled_node.type == "F":
self.schedule_f( self.schedule_f(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,
@ -738,7 +729,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss, accum_loss=accum_loss,
outputs=outputs, outputs=outputs,
) )
it += 1
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) outputs = merge_batch(outputs)
@ -771,9 +761,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
it = 0
# while we still have schedules_node in self.schedules # while we still have schedules_node in self.schedules
while it < len(self.schedules): for it in range(len(self.schedules)):
scheduled_node = self.schedules[it] scheduled_node = self.schedules[it]
print( print(
@ -781,14 +770,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication # communication
if scheduled_node.type == "RECV_FORWARD": communication_func = self.communication_func_map(scheduled_node.type)
self.recv_forward(scheduled_node.chunk) communication_func(scheduled_node.chunk)
elif scheduled_node.type == "RECV_BACKWARD":
self.recv_backward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_FORWARD":
self.send_forward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_BACKWARD":
self.send_backward(scheduled_node.chunk)
if scheduled_node.type == "F": if scheduled_node.type == "F":
self.schedule_f( self.schedule_f(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,
@ -812,7 +796,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
it += 1
# return loss & output # return loss & output
if outputs is not None: if outputs is not None: