mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;
parent 6af81d8c0d
commit 8eb6eac225
@@ -58,14 +58,17 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)
 
-    def backward_b_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+    def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
         """
-        Performs a backward pass for dx, we only calculate dx = w*dy here
+        Performs a backward pass for dx or dw,
+        for dx, we only calculate dx = w*dy here
+        for dw, we only calculate dw = x*dy here
 
         Args:
             tensor (Tensor): y or loss of current chunk;
             grad_tensors (Tensor): dy of current chunk;
-            input_obj (Tensor): x of current chunk;
+            input_obj (Tensor): for dx, input_obj is x of current chunk;
+                for dw, input_obj is w of current chunk;
             retain_graph (bool): default to be True, we retain graph in backward_b
         """
         torch.autograd.backward(
@@ -75,23 +78,6 @@ class OptimizerWrapper:
             retain_graph=retain_graph,
         )
 
-    def backward_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = False):
-        """
-        Performs a backward pass for dw, we only calculate dw = x*dy here
-
-        Args:
-            tensor (Tensor): y or loss of current chunk;
-            grad_tensors (Tensor): dy of current chunk;
-            input_obj (Tensor): w;
-            retain_graph (bool): default to be False, we release graph in backward_w
-        """
-        torch.autograd.backward(
-            tensors=tensors,
-            grad_tensors=grad_tensors,
-            inputs=inputs,
-            retain_graph=retain_graph,
-        )
-
     def state_dict(self):
         """
         Returns the optimizer state.
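For context on what the merged method computes: torch.autograd.backward accepts an inputs argument that restricts which leaf tensors receive gradients, which is what lets one method serve both the B half (dx) and the W half (dw) of a zero-bubble backward step. A minimal standalone sketch of the pattern with toy tensors (not the ColossalAI API):

import torch

# Toy chunk: y = x @ w, with upstream gradient dy coming from the next stage.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
y = x @ w
dy = torch.ones_like(y)

# B step (dx = dy @ w^T only): `inputs` restricts which leaves receive .grad,
# and retain_graph=True keeps the graph alive for the later W step.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# W step (dw = x^T @ dy only): the last consumer of the graph releases it.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w], retain_graph=False)
assert w.grad is not None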
@@ -1,6 +1,32 @@
 # Refer from Zero Bubble Pipeline Parallelism.
 # Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
 # Paper: https://arxiv.org/abs/2401.10241
+# The following applies to all files unless otherwise noted:
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#     * Redistributions of source code must retain the above copyright
+#       notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above copyright
+#       notice, this list of conditions and the following disclaimer in the
+#       documentation and/or other materials provided with the distribution.
+#     * Neither the name of NVIDIA CORPORATION nor the names of its
+#       contributors may be used to endorse or promote products derived
+#       from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from collections import deque
 from dataclasses import dataclass
@@ -46,13 +46,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]
 
-        self.collect_non_loss_data = None
-        self.forward_only = None
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        # self.is_first_run = True
-        # self.optimizer = None
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -166,6 +162,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id
 
+    def communication_func_map(self, node_type: str):
+        return {
+            "SEND_FORWARD": self.send_forward,
+            "RECV_FORWARD": self.recv_forward,
+            "SEND_BACKWARD": self.send_backward,
+            "RECV_BACKWARD": self.recv_backward,
+        }[node_type]
+
     def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
         For ZBV.
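The new communication_func_map helper turns a node-type string into the corresponding bound method via a plain dict lookup, replacing the if/elif chains in both run loops further down; an unknown node type now raises KeyError instead of being silently skipped. A runnable toy version of the same dispatch pattern (illustrative stand-in, not ColossalAI's class):

class ToyStage:
    # Stand-in for the scheduler; only the dispatch pattern matters here.
    def send_forward(self, model_chunk_id):
        print(f"send_forward(chunk={model_chunk_id})")

    def recv_forward(self, model_chunk_id):
        print(f"recv_forward(chunk={model_chunk_id})")

    def communication_func_map(self, node_type: str):
        # Unknown node types fail fast with KeyError.
        return {
            "SEND_FORWARD": self.send_forward,
            "RECV_FORWARD": self.recv_forward,
        }[node_type]

stage = ToyStage()
stage.communication_func_map("RECV_FORWARD")(0)  # prints: recv_forward(chunk=0)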
@@ -439,10 +443,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         if model_chunk_id == 0:
             # bwd step
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-            # )
-            optimizer.backward_b_by_grad(
+            optimizer.backward_b_w_by_grad(
                 tensors=output_obj,
                 grad_tensors=output_obj_grad,
                 inputs=input_obj,
@@ -451,8 +452,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                # torch.autograd.backward(tensors=output_obj, grad_tensors=None, inputs=input_obj, retain_graph=True)
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=None,
                     inputs=input_obj,
@@ -461,10 +461,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
             else:
                 # commom bwd step
-                # torch.autograd.backward(
-                #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
-                # )
-                optimizer.backward_b_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=input_obj,
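In the loss branch above, grad_tensors=None works because output_obj is the scalar loss on the stage that holds it, so autograd seeds the backward pass with 1.0. A small sketch of that case (toy tensors, not the scheduler's objects):

import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2.0).sum()  # scalar, like output_obj on the loss-holding stage

# No upstream dy exists for the loss, so grad_tensors=None lets autograd seed
# d(loss)/d(loss) = 1.0; `inputs` still restricts this B step to dx only.
torch.autograd.backward(tensors=loss, grad_tensors=None, inputs=[x], retain_graph=True)
print(x.grad)  # tensor([2., 2., 2.])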
@@ -495,30 +492,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
-            # torch.autograd.backward(
-            #     tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
-            # )
-            optimizer.backward_w_by_grad(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
+            optimizer.backward_b_w_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=list(model_chunk[model_chunk_id].parameters()),
+                retain_graph=False,
             )
 
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # torch.autograd.backward(tensors=output_obj_grad, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters()))
-                optimizer.backward_w_by_grad(
-                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                optimizer.backward_b_w_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )
             else:
-                # torch.autograd.backward(
-                #     tensors=output_obj,
-                #     grad_tensors=output_obj_grad,
-                #     inputs=list(model_chunk[model_chunk_id].parameters()),
-                # )
-
-                optimizer.backward_w_by_grad(
+                optimizer.backward_b_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
+                    retain_graph=False,
                 )
 
     def schedule_f(
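schedule_w passes the chunk's parameter list as inputs, so only dw is produced, and retain_graph=False lets this last use of the chunk's graph free it. A minimal sketch with a stand-in module (assumed names, not the scheduler's objects):

import torch

model_chunk = torch.nn.Linear(8, 2)  # stand-in for one model chunk
x = torch.randn(4, 8, requires_grad=True)
output_obj = model_chunk(x)
output_obj_grad = torch.ones_like(output_obj)  # dy from the next stage

# The B step ran earlier with retain_graph=True, keeping the graph alive:
torch.autograd.backward(
    tensors=output_obj, grad_tensors=output_obj_grad, inputs=[x], retain_graph=True
)

# W step: gradients land only in the parameters, and the graph is released.
torch.autograd.backward(
    tensors=output_obj,
    grad_tensors=output_obj_grad,
    inputs=list(model_chunk.parameters()),
    retain_graph=False,
)
print(model_chunk.weight.grad.shape)  # torch.Size([2, 8])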
@@ -718,17 +712,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
 
-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]
 
-            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+            if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -738,7 +729,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     accum_loss=accum_loss,
                     outputs=outputs,
                 )
-            it += 1
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
@@ -771,9 +761,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
 
-        it = 0
         # while we still have schedules_node in self.schedules
-        while it < len(self.schedules):
+        for it in range(len(self.schedules)):
             scheduled_node = self.schedules[it]
 
             print(
@@ -781,14 +770,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             )
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
-                if scheduled_node.type == "RECV_FORWARD":
-                    self.recv_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "RECV_BACKWARD":
-                    self.recv_backward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_FORWARD":
-                    self.send_forward(scheduled_node.chunk)
-                elif scheduled_node.type == "SEND_BACKWARD":
-                    self.send_backward(scheduled_node.chunk)
+                communication_func = self.communication_func_map(scheduled_node.type)
+                communication_func(scheduled_node.chunk)
+
             if scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -812,7 +796,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-            it += 1
 
         # return loss & output
         if outputs is not None:
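The loop rewrite in both run functions trades a manually incremented index for `for it in range(...)`: the old form depended on `it += 1` sitting at the very bottom of a long body, where a missed increment or an early `continue` would stall the schedule. A side-by-side sketch with a hypothetical node-type list:

schedules = ["RECV_FORWARD", "F", "SEND_FORWARD"]  # hypothetical node list

# Before: index bookkeeping is the loop body's responsibility.
it = 0
while it < len(schedules):
    node = schedules[it]
    it += 1  # forgetting this (or skipping it via `continue`) loops forever

# After: the for statement owns the index; no increment can be missed.
for it in range(len(schedules)):
    node = schedules[it]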