mirror of https://github.com/hpcaitech/ColossalAI
959 lines
40 KiB
Python
959 lines
40 KiB
Python
from functools import partial
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.cuda
|
|
import torch.distributed
|
|
from torch.nn import Module, ModuleList
|
|
from torch.utils._pytree import tree_flatten, tree_map
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.interface import OptimizerWrapper
|
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
|
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
|
|
|
from ._utils import (
|
|
clone,
|
|
detach,
|
|
get_batch_size,
|
|
get_micro_batch,
|
|
merge_batch,
|
|
model_forward,
|
|
release_tensor_data,
|
|
require_grad,
|
|
retain_grad,
|
|
to_device,
|
|
)
|
|
from .base import PipelineSchedule
|
|
|
|
# Schedule node types that are pure point-to-point communication. When running
# forward+backward, nodes of these types are dispatched through the scheduler's
# communication_map (send/recv methods) instead of a compute step (F/B/W).
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
|
|
|
|
|
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
    """Block on every communication handle in ``wait_handles``.

    Passing ``None`` is allowed and means there is nothing to wait for.
    """
    if wait_handles is None:
        return
    for handle in wait_handles:
        handle.wait()
|
|
|
|
|
|
class ZeroBubbleVPipeScheduler(PipelineSchedule):
    """Zero-bubble pipeline scheduler for the V-shaped ("ZBV") layout with two model chunks.

    Every pipeline stage holds two model chunks. Chunk 0 sends its forward output to the
    next rank (first stage -> last stage). Chunk 1 sends its forward output to the
    previous rank (last stage -> first stage). Chunk 1 on the first stage computes the loss.
    The backward pass is split into a "B" step (input gradient, dx) and a "W" step (weight
    gradient, dw). The order of F/B/W compute nodes and SEND/RECV communication nodes on
    each stage is taken from an externally built schedule.

    Args:
        stage_manager (PipelineStageManager): Pipeline stage manager.
        schedule (List[List[ScheduledNode]]): Per-stage node lists, indexed by
            ``stage_manager.stage``.
        num_model_chunks (int): Number of model chunks per stage. Buffers below are sized for 2.
        num_microbatch (Optional[int]): Number of micro batches per batch.
        microbatch_size (Optional[int]): Size of each micro batch.
            At least one of ``num_microbatch`` / ``microbatch_size`` must be given.
        enable_metadata_cache (bool): Cache P2P tensor metadata after the first exchange.
        overlap_p2p (bool): Forwarded to ``PipelineP2PCommunication``.
    """

    def __init__(
        self,
        stage_manager: PipelineStageManager,
        schedule: List[List[ScheduledNode]],
        num_model_chunks: int,
        num_microbatch: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
        overlap_p2p: bool = True,
    ):
        super().__init__(stage_manager)
        # batch info
        self.num_microbatch = num_microbatch
        self.microbatch_size = microbatch_size
        self.num_model_chunks = num_model_chunks
        self.batch: Any
        self.batch_size: int
        self.last_batch_size: Optional[int] = None
        self.microbatch_offset: List[int]

        # Per-stage schedules; the run_* methods pick the list for stage_manager.stage.
        self.schedules = schedule
        # TODO: optim post valid
        self.do_post_validation = False

        # P2PMeta cache
        self.enable_metadata_cache = enable_metadata_cache

        # check send_tensor_metadata, send_grad_metadata
        # pp4 as sample, we should follow this meta strategy
        #           send_tensor_meta(fwd)   send_grad_meta(bwd)
        #           chunk0 | chunk1         chunk0 | chunk 1
        # stage 0   T      |   F            F      |   T
        # stage 1   T      |   T            T      |   T
        # stage 2   T      |   T            T      |   T
        # stage 3   F      |   T            F      |   T
        if stage_manager.is_first_stage(ignore_chunk=True):
            self.send_tensor_metadata = [True, False]
            self.send_grad_metadata = [False, True]
        elif stage_manager.is_last_stage(ignore_chunk=True):
            self.send_tensor_metadata = [False, True]
            self.send_grad_metadata = [True, False]
        else:
            self.send_tensor_metadata = [True, True]
            self.send_grad_metadata = [True, True]

        # meta cache buffer
        self.tensor_metadata_recv = [None, None]  # [chunk 0 meta, chunk 1 meta]
        self.grad_metadata_recv = [None, None]

        # P2P communication
        self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)

        # init communication map: schedule node type -> bound send/recv method
        self.communication_map = {
            "SEND_FORWARD": self.send_forward,
            "RECV_FORWARD": self.recv_forward,
            "SEND_BACKWARD": self.send_backward,
            "RECV_BACKWARD": self.recv_backward,
        }

        # init buffer
        self._free_buffers()

    def _free_buffers(self):
        """(Re)initialise every local buffer to empty.

        Per-chunk buffers are two-level lists: the first index is the model chunk and the
        second is a FIFO queue of micro batches.
        """
        # x & y buffer for schedule b
        self.input_tensors = [[], []]
        self.output_tensors = [[], []]

        # y & dy buffer for schedule w
        self.output_tensors_dw = [[], []]
        self.output_tensors_grad_dw = [[], []]

        # buffer for communication
        self.send_forward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
        self.recv_forward_buffer = [
            [],
            [],
        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
        self.send_backward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
        self.recv_backward_buffer = [
            [],
            [],
        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]

        # y buffer for local send fwd
        self.local_send_forward_buffer = []
        # dy buffer for local send bwd
        self.local_send_backward_buffer = []

        # handles of outstanding sends, waited on at the end of a run
        self.wait_handles = []

    def assert_buffer_empty(self):
        """Assert that every micro-batch buffer has been fully drained at the end of a run."""
        per_chunk_buffers = (
            self.input_tensors,
            self.output_tensors,
            self.output_tensors_dw,
            self.output_tensors_grad_dw,
            self.send_forward_buffer,
            self.recv_forward_buffer,
            self.send_backward_buffer,
            self.recv_backward_buffer,
        )
        for buffer in per_chunk_buffers:
            assert len(buffer[0]) == 0
            assert len(buffer[1]) == 0
        assert len(self.local_send_forward_buffer) == 0
        assert len(self.local_send_backward_buffer) == 0

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from the data iterator and derive the micro-batch layout.

        If only one of ``num_microbatch`` / ``microbatch_size`` was configured, the other
        is derived from the batch size.

        Args:
            data_iter (Iterable): Data iterator.
            device (Optional[torch.device], optional): Target device. Defaults to None.
        """
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)

        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
        self.batch = batch
        self.batch_size = get_batch_size(batch)

        # Without either value the divisibility checks below would fail with a TypeError.
        assert (
            self.microbatch_size is not None or self.num_microbatch is not None
        ), "Either microbatch_size or num_microbatch should be provided"
        if self.microbatch_size is None:
            assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
            self.microbatch_size = self.batch_size // self.num_microbatch
        if self.num_microbatch is None:
            assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
            self.num_microbatch = self.batch_size // self.microbatch_size

        if not self.forward_only:
            assert self.last_batch_size is None or self.last_batch_size == self.batch_size
            assert self.batch_size == self.microbatch_size * self.num_microbatch

            assert (
                self.num_microbatch % self.stage_manager.num_stages == 0
            ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"

        if self.forward_only:
            # ceil division: the last micro batch may be partial in forward-only mode
            self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1

        self.last_batch_size = self.batch_size

    def load_micro_batch(self, model_chunk_id: int) -> Any:
        """Load the next micro batch for ``model_chunk_id`` from the current batch.

        Each model chunk has its own read offset into ``self.batch``, which is advanced by
        ``microbatch_size`` on every call.

        Args:
            model_chunk_id (int): the current model chunk idx.

        Returns:
            Any: Micro batch, moved to the current accelerator device.
        """
        # An offset equal to batch_size means every sample was consumed; reading there
        # would presumably yield an empty micro batch instead of failing.
        assert self.microbatch_offset[model_chunk_id] < self.batch_size, "Microbatches exhausted"
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
        self.microbatch_offset[model_chunk_id] += self.microbatch_size
        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)

    def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
        """Helper method to get the model chunk ID given the iteration number.

        Args:
            microbatch_id (int): the current microbatch idx
            is_forward (bool): whether this is the forward pass; the order is reversed otherwise

        Returns:
            int: The model chunk idx of the input microbatch_id
        """
        assert (
            microbatch_id < self.num_microbatch * self.num_model_chunks
        ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
        microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
        model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
        if not is_forward:
            # Reverse order
            model_chunk_id = self.num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
        """Post the receive of this stage's forward input (y of the neighbour) for ZBV.

        Chunk 0 receives from the previous rank and chunk 1 receives from the next rank.
        The received object and its wait handles are queued on ``recv_forward_buffer`` and
        consumed by ``schedule_f``. Chunk 0 on the first stage and chunk 1 on the last stage
        receive nothing.

        Args:
            model_chunk_id (int): The current model chunk idx.
            prev_rank (int, optional): Ignored; the peer rank comes from the stage manager.

        Returns:
            List: The wait handles for the communication (empty if nothing was received).
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if model_chunk_id == 0:
                # chunk 0 & first stage: no prev rank, the input comes from the micro batch.
                if self.stage_manager.is_first_stage(ignore_chunk=True):
                    return []
                # chunk 0 & not first stage: recv y from PREV rank as input.
                peer_rank = self.stage_manager.get_prev_rank()
                input_tensor, wait_handles = self.comm.recv_forward(
                    prev_rank=peer_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
                )
            else:
                # chunk 1 & last stage: y is taken from local_send_forward_buffer in schedule_f.
                if self.stage_manager.is_last_stage(ignore_chunk=True):
                    return []
                # chunk 1 & not last stage: recv y from NEXT rank as input.
                peer_rank = self.stage_manager.get_next_rank()
                input_tensor, wait_handles = self.comm.recv_forward(
                    peer_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
                )

            if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
                self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
            self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
            return wait_handles

    def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
        """Post the receive of this stage's output gradient (dy) for ZBV.

        In backward, chunk 0 (right side of the V) receives from the next rank and chunk 1
        (left side of the V) receives from the previous rank. The result and its wait handles
        are queued on ``recv_backward_buffer`` and consumed by ``schedule_b``. Chunk 0 on the
        last stage and chunk 1 on the first stage receive nothing.

        Args:
            model_chunk_id (int): The current model chunk idx.
            next_rank (int, optional): Ignored; the peer rank comes from the stage manager.

        Returns:
            List: The wait handles for the communication (empty if nothing was received).
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if model_chunk_id == 0:
                # chunk 0 & last stage: dy comes from local_send_backward_buffer in schedule_b.
                if self.stage_manager.is_last_stage(ignore_chunk=True):
                    return []
                # chunk 0 & not last stage: recv dy from NEXT stage.
                peer_rank = self.stage_manager.get_next_rank()
                output_tensor_grad, wait_handles = self.comm.recv_backward(
                    peer_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
                )
            else:
                # chunk 1 & first stage: backward starts from the local loss.
                if self.stage_manager.is_first_stage(ignore_chunk=True):
                    return []
                # chunk 1 & not first stage: recv dy from PREV stage.
                peer_rank = self.stage_manager.get_prev_rank()
                output_tensor_grad, wait_handles = self.comm.recv_backward(
                    next_rank=peer_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
                )

            if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
                self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
            self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
            return wait_handles

    def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
        """Send the oldest queued forward output of ``model_chunk_id`` to its peer for ZBV.

        Chunk 0 sends to the next rank and chunk 1 sends to the previous rank. Chunk 0 on the
        last stage and chunk 1 on the first stage keep their output locally and send nothing.

        Args:
            model_chunk_id (int): The current model chunk idx.
            next_rank (int, optional): Ignored; the peer rank comes from the stage manager.

        Returns:
            List: The wait handles for the communication (empty if nothing was sent).
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if model_chunk_id == 0:
                # chunk 0 & last stage: y stays in local_send_forward_buffer.
                if self.stage_manager.is_last_stage(ignore_chunk=True):
                    self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                    return []
                # chunk 0 & not last stage: send y to NEXT stage.
                peer_rank = self.stage_manager.get_next_rank()
                output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                send_handles = self.comm.send_forward(
                    output_object=output_tensor,
                    next_rank=peer_rank,
                    send_metadata=self.send_tensor_metadata[model_chunk_id],
                )
            else:
                # chunk 1 & first stage: end of forward (loss), nothing to send.
                if self.stage_manager.is_first_stage(ignore_chunk=True):
                    self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                    return []
                # chunk 1 & not first stage: send y to PREV stage.
                peer_rank = self.stage_manager.get_prev_rank()
                output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                send_handles = self.comm.send_forward(
                    output_tensor, peer_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
                )

            # After the first send, metadata is only re-sent when caching is disabled.
            self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
            return send_handles

    def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
        """Send the oldest queued input gradient (dx) of ``model_chunk_id`` to its peer for ZBV.

        In backward, chunk 0 sends to the previous rank and chunk 1 sends to the next rank.
        Chunk 0 on the first stage (end of backward) and chunk 1 on the last stage (gradient
        kept in local_send_backward_buffer) send nothing.

        Args:
            model_chunk_id (int): The current model chunk idx.
            prev_rank (int, optional): Ignored; the peer rank comes from the stage manager.

        Returns:
            List: The wait handles for the communication (empty if nothing was sent).
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            if model_chunk_id == 0:
                # chunk 0 & first stage: end of backward, nothing to send.
                if self.stage_manager.is_first_stage(ignore_chunk=True):
                    self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                    return []
                # chunk 0 & not first stage: send dx to PREV stage.
                peer_rank = self.stage_manager.get_prev_rank()
                input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                send_handles = self.comm.send_backward(
                    input_tensor_grad, peer_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
                )
            else:
                # chunk 1 & last stage: dx already placed in local_send_backward_buffer by schedule_b.
                if self.stage_manager.is_last_stage(ignore_chunk=True):
                    self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                    return []
                # chunk 1 & not last stage: send dx to NEXT stage.
                peer_rank = self.stage_manager.get_next_rank()
                input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                send_handles = self.comm.send_backward(
                    input_tensor_grad, peer_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
                )

            # After the first send, metadata is only re-sent when caching is disabled.
            self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
            return send_handles

    def forward_step(
        self,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        micro_batch: Optional[dict],
        input_obj: Optional[dict],
        criterion: Callable,
        accum_loss: Optional[torch.Tensor] = None,
        outputs: Optional[List[Any]] = None,
    ) -> Union[torch.Tensor, dict]:
        """Forward one step of the pipeline.

        Args:
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            micro_batch (Optional[dict]): the current micro batch;
            input_obj (Optional[dict]): x; None for chunk 0 on the first stage;
            criterion (Callable): loss function;
            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.

        Returns:
            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage, or
            the loss (Tensor, scaled by 1/num_microbatch) for chunk 1 on the first stage.
        """
        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
            # Shallow copy so that adding "stage_index" does not mutate the caller's
            # input_obj, which schedule_f keeps for the backward pass.
            internal_inputs = {} if input_obj is None else dict(input_obj)
            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

            # chunk 1 on the first stage is the last layer of the model: compute the loss.
            if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
                loss = criterion(output_obj, micro_batch) / self.num_microbatch
                if accum_loss is not None:
                    accum_loss.add_(loss.detach())
                if outputs is not None:
                    outputs.append(tree_map(detach, output_obj))
                return loss
            return output_obj

    def backward_b_step(
        self,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
        input_obj: Optional[dict],
        output_obj: Union[dict, torch.Tensor],
        output_obj_grad: Optional[dict],
    ) -> Optional[dict]:
        """Backward dx step of the pipeline; we calculate "dx = w*dy" here.

        Args:
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            optimizer (OptimizerWrapper): Optimizer providing ``backward_by_grad``;
            input_obj (Optional[dict]): x; None for chunk 0 on the first stage;
            output_obj (Union[dict, torch.Tensor]): y, or the loss for chunk 1 on the first stage;
            output_obj_grad (Optional[dict]): dy; None when output_obj is the loss.

        Returns:
            Optional[dict]: dx, mapping input_obj keys to gradients (empty for chunk 0 on the first stage).
        """
        # Retain the grad on the input_obj so dx can be read after backward.
        if input_obj is not None:
            tree_map(retain_grad, input_obj)

        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
            # Loss backward: output_obj is the loss, so there is no upstream gradient.
            assert output_obj_grad is None
            output_obj_ = [output_obj]
            output_obj_grad_ = [output_obj_grad]
        else:
            output_obj_, _ = tree_flatten(output_obj)  # y
            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

        # keep only tensors (and None placeholders)
        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

        try:
            ctx = optimizer.no_sync()
        except AttributeError:
            ctx = model_chunk.no_sync()
        with ctx:
            # `inputs` is intentionally not restricted. Weight-gradient work is presumably
            # deferred through WeightGradStore (flushed in schedule_b, popped in schedule_w).
            optimizer.backward_by_grad(
                tensor=output_obj_,
                grad=output_obj_grad_,
                retain_graph=False,
            )

        # Format dx: chunk 0 on the first stage consumed the raw micro batch, so there is no dx.
        input_obj_grad = {}
        if not (model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True)):
            for k, v in input_obj.items():
                if isinstance(v, torch.Tensor) and v.grad is not None:
                    input_obj_grad[k] = v.grad
        return input_obj_grad

    def backward_w_step(
        self,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
        output_obj: Union[dict, torch.Tensor],
        output_obj_grad: Optional[dict],
    ):
        """Backward dw step of the pipeline; we calculate "dw = x*dy" here.

        Gradients are computed only with respect to ``model_chunk.parameters()``.

        Args:
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            optimizer (OptimizerWrapper): Optimizer providing ``backward_by_grad``;
            output_obj (Union[dict, torch.Tensor]): y, or the loss for chunk 1 on the first stage;
            output_obj_grad (Optional[dict]): dy.

        Returns:
            None.
        """
        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
            # loss backward; output_obj is the loss and needs no upstream gradient.
            output_obj_ = [output_obj]
            output_obj_grad_ = [None]
        else:
            output_obj_, _ = tree_flatten(output_obj)  # y
            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

        # keep only tensors (and None placeholders)
        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

        optimizer.backward_by_grad(
            tensor=output_obj_,
            grad=output_obj_grad_,
            inputs=list(model_chunk.parameters()),
            retain_graph=False,
        )

    def schedule_f(
        self,
        scheduled_node,
        model_chunk: torch.nn.ModuleList,
        model_chunk_id: int,
        criterion: Callable,
        accum_loss: Optional[torch.Tensor] = None,
        outputs: Optional[List[Any]] = None,
    ):
        """A complete forward schedule: take input --> forward step --> queue output for sending.

        In forward-only mode, the inputs and outputs are not stashed for a backward pass.

        Args:
            scheduled_node: The schedule node being executed (unused here).
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            criterion (Callable): loss function;
            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.

        Returns:
            Nothing.
        """
        micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
        is_first_stage = self.stage_manager.is_first_stage(ignore_chunk=True)
        is_last_stage = self.stage_manager.is_last_stage(ignore_chunk=True)
        # chunk 1 on the first stage is the tail of the V; its output is the loss.
        is_loss_chunk = model_chunk_id == 1 and is_first_stage

        # Step1: obtain the forward input x
        if model_chunk_id == 0:
            if is_first_stage:
                # input comes from the micro batch
                input_obj = None
            else:
                input_obj, recv_handles = self.recv_forward_buffer[model_chunk_id].pop(0)
                _wait_p2p(recv_handles)
        else:
            if is_last_stage:
                # chunk 0's output on this stage, handed over locally
                input_obj = self.local_send_forward_buffer.pop(0)
            else:
                input_obj, recv_handles = self.recv_forward_buffer[model_chunk_id].pop(0)
                _wait_p2p(recv_handles)

        # Make received inputs require grad so dx can be computed in schedule_b.
        if not isinstance(input_obj, torch.Tensor):
            tree_map(require_grad, input_obj)

        # Step2: fwd step
        output_obj = self.forward_step(
            model_chunk=model_chunk,
            model_chunk_id=model_chunk_id,
            micro_batch=micro_batch,
            input_obj=input_obj,
            criterion=criterion,
            accum_loss=accum_loss,
            outputs=outputs,
        )

        # Step3: prepare the copy to send and free the original's data
        detached_output_obj = None
        if not is_loss_chunk:
            # Send a detached clone; the original keeps only its graph (grad_fn) for B & W.
            detached_output_obj = tree_map(clone, tree_map(detach, output_obj))
            tree_map(release_tensor_data, output_obj)
        # The loss is neither detached nor released: backward starts from it.

        # stash x and y for backward b (skipped when no backward will run)
        if not getattr(self, "forward_only", False):
            self.input_tensors[model_chunk_id].append(input_obj)
            self.output_tensors[model_chunk_id].append(output_obj)

        # queue the output for sending
        if model_chunk_id == 0:
            if is_last_stage:
                self.local_send_forward_buffer.append(detached_output_obj)
            else:
                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
        elif not is_first_stage:
            self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
        # else: chunk 1 on the first stage is the end of forward; nothing to send

    def schedule_b(
        self,
        scheduled_node,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
    ):
        """A complete backward b schedule: obtain dy --> backward b step --> queue dx for sending.

        Args:
            scheduled_node: The schedule node being executed (unused here).
            model_chunk (ModuleList or Module): Model Chunk to be run;
            model_chunk_id (int): The current model chunk idx;
            optimizer (OptimizerWrapper): Optimizer providing ``backward_by_grad``.

        Returns:
            Nothing.
        """
        # Step1: obtain dy
        if model_chunk_id == 0:
            if self.stage_manager.is_last_stage(ignore_chunk=True):
                # dy produced locally by chunk 1's B step on this stage
                output_tensor_grad = self.local_send_backward_buffer.pop(0)
            else:
                output_tensor_grad, recv_handles = self.recv_backward_buffer[model_chunk_id].pop(0)
                _wait_p2p(recv_handles)
        else:
            if self.stage_manager.is_first_stage(ignore_chunk=True):
                # backward starts from the loss; no upstream gradient
                output_tensor_grad = None
            else:
                output_tensor_grad, recv_handles = self.recv_backward_buffer[model_chunk_id].pop(0)
                _wait_p2p(recv_handles)

        # Step2: backward b with the stashed x and y
        input_obj = self.input_tensors[model_chunk_id].pop(0)
        output_obj = self.output_tensors[model_chunk_id].pop(0)

        input_object_grad = self.backward_b_step(
            model_chunk=model_chunk,
            model_chunk_id=model_chunk_id,
            optimizer=optimizer,
            input_obj=input_obj,
            output_obj=output_obj,
            output_obj_grad=output_tensor_grad,
        )

        # Step3: queue dx for sending
        if model_chunk_id == 0:
            # chunk 0 on the first stage is the end of backward
            if not self.stage_manager.is_first_stage(ignore_chunk=True):
                self.send_backward_buffer[model_chunk_id].append(input_object_grad)
        else:
            if self.stage_manager.is_last_stage(ignore_chunk=True):
                # handed to chunk 0's B step on this stage
                self.local_send_backward_buffer.append(input_object_grad)
            else:
                self.send_backward_buffer[model_chunk_id].append(input_object_grad)
        WeightGradStore.flush(chunk=model_chunk_id)

    def schedule_w(
        self,
        scheduled_node,
        model_chunk: Union[ModuleList, Module],
        model_chunk_id: int,
        optimizer: OptimizerWrapper,
    ):
        """A complete backward w schedule: run the deferred weight-gradient work for the chunk.

        Args:
            scheduled_node: The schedule node being executed (unused here).
            model_chunk (ModuleList or Module): Model Chunk to be run (unused here);
            model_chunk_id (int): The current model chunk idx;
            optimizer (OptimizerWrapper): Unused here.

        Returns:
            Nothing.
        """
        WeightGradStore.pop(chunk=model_chunk_id)

    def _wait_send_handles(self) -> None:
        """Wait for every queued send to complete, then clear the queue so it cannot grow across runs."""
        for handles in self.wait_handles:
            _wait_p2p(handles)
        self.wait_handles = []

    def run_forward_only(
        self,
        model_chunk: Union[ModuleList, Module],
        data_iter: Iterable,
        criterion: Callable[..., Any],
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> Dict:
        """Run only the forward nodes (and their communication) of this stage's schedule.

        Returns:
            Dict: ``{"loss": accumulated loss or None, "outputs": merged outputs or None}``.
        """
        assert self.forward_only

        # prepare batch
        self.load_batch(data_iter)

        # prepare accum loss & output (only the first stage computes the loss)
        accum_loss = None
        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())

        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None

        # self.schedules holds one node list per stage; run this stage's list.
        schedule = self.schedules[self.stage_manager.stage]
        for scheduled_node in schedule:
            if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
                communication_func = self.communication_map[scheduled_node.type]
                wait_handle = communication_func(scheduled_node.chunk)
                # recv handles are waited in schedule_f; only sends are waited here
                if scheduled_node.type == "SEND_FORWARD":
                    self.wait_handles.append(wait_handle)
            elif scheduled_node.type == "F":
                self.schedule_f(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
                    model_chunk_id=scheduled_node.chunk,
                    criterion=criterion,
                    accum_loss=accum_loss,
                    outputs=outputs,
                )
            # backward nodes (B/W and backward communication) are skipped

        # wait here to ensure all communication is done
        self._wait_send_handles()

        # return loss & output
        if outputs is not None:
            outputs = merge_batch(outputs)
        return {"loss": accum_loss, "outputs": outputs}

    def run_forward_backward(
        self,
        model_chunk: Union[ModuleList, Module],
        data_iter: Iterable,
        criterion: Callable[..., Any],
        optimizer: Optional[OptimizerWrapper] = None,
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> Dict:
        """
        Runs Zerobubble schedule, with communication between pipeline stages.

        Returns:
            Dict: ``{"loss": accumulated loss or None, "outputs": merged outputs or None}``.
        """
        # prepare batch
        self.load_batch(data_iter)

        # prepare accum loss & output (only the first stage computes the loss)
        accum_loss = None
        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())

        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None

        schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
        for scheduled_node in schedule:
            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                communication_func = self.communication_map[scheduled_node.type]
                wait_handle = communication_func(scheduled_node.chunk)
                # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
                if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
                    self.wait_handles.append(wait_handle)
            elif scheduled_node.type == "F":
                self.schedule_f(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
                    model_chunk_id=scheduled_node.chunk,
                    criterion=criterion,
                    accum_loss=accum_loss,
                    outputs=outputs,
                )
            elif scheduled_node.type == "B":
                self.schedule_b(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
                    model_chunk_id=scheduled_node.chunk,
                    optimizer=optimizer,
                )
            elif scheduled_node.type == "W":
                self.schedule_w(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
                    model_chunk_id=scheduled_node.chunk,
                    optimizer=optimizer,
                )

        # wait here to ensure all communication is done (also clears the handle queue)
        self._wait_send_handles()

        # return loss & output
        if outputs is not None:
            outputs = merge_batch(outputs)
        return {"loss": accum_loss, "outputs": outputs}

    def forward_backward_step(
        self,
        model_chunk: Union[ModuleList, Module],
        data_iter: Iterable,
        criterion: Callable[..., Any],
        optimizer: Optional[OptimizerWrapper] = None,
        return_loss: bool = False,
        return_outputs: bool = False,
    ) -> dict:
        """
        Args:
            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
            data_iter (Iterable): Data iterator.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: A dict with keys: 'loss' and 'outputs'.
        """
        # Forward-only mode is inferred from the autograd state (e.g. under torch.no_grad()).
        self.forward_only = not torch.is_grad_enabled()
        if optimizer is None:
            assert self.forward_only, "Optimizer should be passed when doing backward."

        if self.forward_only:
            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
        else:
            result = self.run_forward_backward(
                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
            )

        self.assert_buffer_empty()
        return result
|