[zerobubble] Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requires_grad for output;

* [fix] fix requires_grad position, detach position, and input & output local buffer append position;

* [feat] add memory assertion;

* [fix] fix mem check;

* [fix] mem assertion;

* [fix] fix mem assertion

* [fix] fix mem; use a new model shape; only assert mem less than or equal to theoretical;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.requires_grad_ with tree_map;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;
duanjunwen 2024-09-10 17:33:09 +08:00 committed by GitHub
parent 7cf9df07bc
commit 11ae6848c6
7 changed files with 2131 additions and 4 deletions
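At a high level, this commit adds two cooperating pieces: PipelineGraph (v_schedule.py) plans a zero-bubble V schedule from per-stage F/B/W compute costs and memory estimates, and ZeroBubbleVPipeScheduler (zero_bubble_pp.py) executes that plan on each rank, splitting every backward pass into a B step (dx) and a W step (dw). A minimal wiring sketch, with placeholder cost/memory values and a stage_manager assumed to be a PipelineStageManager created with enable_interleave=True as in the test added here:

from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler

graph = PipelineGraph(
    n_stage=4, n_micro=4,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=1, b_mem=1, w_mem=1,
)
zbv_schedule = graph.get_v_schedule()  # one List[ScheduledNode] per pipeline stage
scheduler = ZeroBubbleVPipeScheduler(
    stage_manager=stage_manager,  # assumed: PipelineStageManager(..., enable_interleave=True)
    schedule=zbv_schedule,
    num_model_chunks=2,
    num_microbatch=4,
)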


@@ -1103,7 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)


@@ -55,8 +55,25 @@ class OptimizerWrapper:
"""
loss.backward(*args, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensor, grad)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here
for dw, we only calculate dw = x*dy here
Args:
tensor (Tensor): y or loss of current chunk;
grad_tensors (Tensor): dy of current chunk;
input_obj (Tensor): for dx, input_obj is x of current chunk;
for dw, input_obj is w of current chunk;
retain_graph (bool): default to be True, we retain graph in backward_b
"""
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)
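# Illustration (not part of this diff): the ZeroBubble scheduler splits one backward pass into two
# calls to this method. Assuming y = model_chunk(x) and dy is the gradient w.r.t. y:
#   optimizer.backward_by_grad(tensor=y, grad=dy, inputs=x, retain_graph=True)   # B step: dx only
#   optimizer.backward_by_grad(tensor=y, grad=dy,
#                              inputs=list(model_chunk.parameters()),
#                              retain_graph=False)                               # W step: dw only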
def state_dict(self):
"""


@@ -1,11 +1,12 @@
from .p2p import PipelineP2PCommunication
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
from .stage_manager import PipelineStageManager
__all__ = [
"PipelineSchedule",
"OneForwardOneBackwardSchedule",
"InterleavedSchedule",
"ZeroBubbleVPipeScheduler",
"PipelineP2PCommunication",
"PipelineStageManager",
]


@@ -1,9 +1,11 @@
from .base import PipelineSchedule
from .interleaved_pp import InterleavedSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule
from .zero_bubble_pp import ZeroBubbleVPipeScheduler
__all__ = [
"PipelineSchedule",
"OneForwardOneBackwardSchedule",
"InterleavedSchedule",
"ZeroBubbleVPipeScheduler",
]


@@ -0,0 +1,494 @@
# Refer from Zero Bubble Pipeline Parallelism.
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
# Paper: https://arxiv.org/abs/2401.10241
# The following applies to all files unless otherwise noted:
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import deque
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
type: str
chunk: int
stage: int
minibatch: int
start_time: int = 0
completion_time: int = 0
rollback: bool = False
class PipelineGraph(object):
"""PipelineGraph"""
def __init__(
self,
n_stage,
n_micro,
f_cost,
b_cost,
w_cost,
c_cost,
f_mem,
b_mem,
w_mem,
max_mem=None,
):
self.n_node = 6 * n_stage * n_micro
self.n_stage = n_stage
self.n_micro = n_micro
self.f_cost = f_cost
self.b_cost = b_cost
self.w_cost = w_cost
self.c_cost = c_cost
self.f_mem = f_mem
self.b_mem = b_mem
self.w_mem = w_mem
self.fbw_cost = [f_cost, b_cost, w_cost]
self.fbw_mem = [f_mem, b_mem, w_mem]
self.max_mem = max_mem or f_mem * self.n_stage * 2
def get_id(self, cat, chunk, stage, micro):
return (
cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
)
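# Worked example of the id layout (not in the original file): with n_stage=4 and n_micro=4,
# get_id(cat=1, chunk=0, stage=2, micro=3) = 1*2*4*4 + 0*4*4 + 2*4 + 3 = 43;
# ids are grouped first by pass type (0=F, 1=B, 2=W), then by chunk, stage and microbatch.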
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
count = []
for i in range(self.n_stage):
count.append([0] * 6)
end_time = [-1] * self.n_node
cur_time = [0] * self.n_stage
mem = [0] * self.n_stage
stage_bubble = [0] * self.n_stage
pending_w = [deque() for _ in range(self.n_stage)]
schedule = [[] for _ in range(self.n_stage)]
stage_str = [" " * i for i in range(self.n_stage)]
if approved_bubble is None:
approved_bubble = [-1] * self.n_stage
max_approved_bubble = max(approved_bubble)
def get_max_stage_bubble(stage=-1):
max_stage_bubble = 0
for bb in stage_bubble:
max_stage_bubble = max(max_stage_bubble, bb)
if stage >= 0:
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
return max_stage_bubble
def put_w(stage):
assert len(pending_w[stage]) > 0
_, chunk_, _ = pending_w[stage].popleft()
put(2, chunk_, stage)
def put(cat, chunk, stage, assert_cnt=True):
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
_cnt = count[stage][cat * 2 + chunk]
# assert _cnt < self.n_micro
if _cnt >= self.n_micro:
if not assert_cnt:
stage_str[stage] += " "
cur_time[stage] = _tmp # TODO
return
assert False
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
if cat > 0 or chunk > 0:
last_id = cat * 2 + chunk - 1
if cat < 2:
# if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0:
# print(cat, chunk, stage, _cnt)
# self.print_details(end_time)
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
else:
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
if chunk == 1 and cat < 2:
if stage < self.n_stage - 1:
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
assert end_time[_fa_id] >= 0
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
if chunk == 0 and cat < 2:
if stage > 0:
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
# if end_time[_fa_id] < 0:
# print(cat, chunk, stage, _cnt)
# self.print_details(end_time)
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
_id = self.get_id(cat, chunk, stage, _cnt)
if count[stage][0] > 0:
stage_bubble[stage] += _tmp - _no_bubble
end_time[_id] = _tmp
cur_time[stage] = _tmp
mem[stage] += self.fbw_mem[cat]
# noinspection PyTypeChecker
schedule[stage].append((cat, chunk, _cnt))
if cat == 1:
pending_w[stage].append((2, chunk, _cnt))
count[stage][cat * 2 + chunk] += 1
# for _ in range(2 * self.n_stage):
# for i in range(self.n_stage):
# if count[i][1] >= count[i][0]:
# put(0, 0, i, assert_cnt=False)
# continue
# if i == self.n_stage - 1:
# put(0, 1, i, assert_cnt=False)
# continue
# fa_id = self.get_id(0, 1, i + 1, count[i][1])
# if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO
# put(0, 1, i, assert_cnt=False)
# else:
# put(0, 0, i, assert_cnt=False)
for i in range(self.n_stage):
put(0, 0, i)
for i in range(self.n_stage - 1, -1, -1):
if i == self.n_stage - 1:
put(0, 1, i)
continue
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
while (
mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
and cur_time[i] + self.fbw_cost[0] <= tmp
and count[i][0] < self.n_micro
):
for j in range(i + 1):
put(0, 0, j)
put(0, 1, i)
iter_chunk_ = 0
end_tmp = 0
for i in range(self.n_stage):
if i == 0:
end_tmp = cur_time[0] + self.fbw_cost[1]
continue
tmp = end_tmp + self.c_cost
while (
count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
or count[i][1] <= count[i - 1][1] < self.n_micro
):
for j in range(self.n_stage - 1, i - 1, -1):
if count[j][iter_chunk_] < self.n_micro:
put(0, iter_chunk_, j)
iter_chunk_ = 1 - iter_chunk_
# while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp:
# if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]:
# break
# for j in range(self.n_stage - 1, i - 1, -1):
# if count[j][iter_chunk_] < self.n_micro:
# put(0, iter_chunk_, j)
# iter_chunk_ = 1 - iter_chunk_
# end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1]
# init_bubble = get_max_stage_bubble()
# print(stage_bubble)
for _ in range(2 * self.n_micro):
# check mem before putting b
for i in range(self.n_stage):
while mem[i] + self.fbw_mem[1] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
b0_ranks, b1_ranks = [], []
for i in range(self.n_stage):
if count[i][3] >= count[i][2]:
b0_ranks.append(i)
elif i == self.n_stage - 1:
b1_ranks.append(i)
else:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
b1_ranks.append(i)
else:
b0_ranks.append(i)
b_ranks = []
# put b1
for i in reversed(b1_ranks):
b_ranks.append((i, 1))
# put b0
for i in b0_ranks:
b_ranks.append((i, 0))
for i, _chunk_ in b_ranks:
fa_id = -1
if _chunk_ == 1 and i < self.n_stage - 1:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if _chunk_ == 0 and i > 0:
fa_id = self.get_id(1, 0, i - 1, count[i][2])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if _chunk_ == 1:
put_w(i)
elif fill_b:
put_w(i)
put(1, _chunk_, i)
# put f
for i in range(self.n_stage):
if count[i][1] >= self.n_micro:
continue
put_item = None
if count[i][1] >= count[i][0]:
put_item = 0
elif i == self.n_stage - 1:
put_item = 1
else:
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
put_item = 1
elif count[i][0] < self.n_micro:
if i == 0:
put_item = 0
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
put_item = 0
if put_item is None:
continue
# check mem before putting f
while mem[i] + self.fbw_mem[0] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
fa_id = -1
if put_item == 0 and i > 0:
fa_id = self.get_id(0, 0, i - 1, count[i][0])
if put_item == 1 and i < self.n_stage - 1:
fa_id = self.get_id(0, 1, i + 1, count[i][1])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if fill_f:
put_w(i)
put(0, put_item, i)
for i in range(self.n_stage):
while len(pending_w[i]) > 0:
put_w(i)
# for i in range(self.n_stage):
# print(stage_str[i])
max_bubble = get_max_stage_bubble()
expected_time = sum(self.fbw_cost) * self.n_micro * 2
bubble_rate = max_bubble / expected_time
# print("%6.4f" % bubble_rate, "->", stage_bubble)
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
_schedule, _end_time, _max_bubble = self.try_v_schedule(
fill_f=fill_f,
fill_b=fill_b,
approved_bubble=stage_bubble,
)
if _max_bubble < max_bubble:
return _schedule, _end_time, _max_bubble
# print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \
# (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble)
return schedule, end_time, max_bubble
def print_details(self, end_time, print_scaling=1):
for stage in range(self.n_stage):
stage_str = ["."] * int(max(end_time) / print_scaling)
for _cat in range(3):
for _chunk in range(2):
for _micro in range(self.n_micro):
_id = self.get_id(_cat, _chunk, stage, _micro)
if end_time[_id] < 0:
continue
end = int(end_time[_id] / print_scaling)
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
for j in range(start, end):
if j == start or j == end - 1:
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
elif j == start + 1:
if _micro >= 10:
stage_str[j] = str(_micro // 10)
else:
stage_str[j] = str(_micro)
elif j == start + 2 and _micro >= 10:
stage_str[j] = str(_micro % 10)
else:
stage_str[j] = "-"
_str = ""
for _c in stage_str:
_str += _c
print(_str)
def get_v_schedule(self, only_run_time=False):
schedule, end_time, max_bubble = None, None, None
expected_time = sum(self.fbw_cost) * self.n_micro * 2
for fill_b in [True, False]:
for fill_f in [True, False]:
_schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
# print("")
if max_bubble is None or _max_bubble < max_bubble:
max_bubble = _max_bubble
schedule = _schedule
end_time = _end_time
if only_run_time:
return max_bubble + expected_time
# self.print_details(end_time, print_scaling=1)
bubble_rate = max_bubble / (expected_time + max_bubble)
# print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \
# (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate))
local_order = [[] for _ in range(self.n_stage)]
comm_id = {}
comm_id_counter = 0
post_validation_time = 0
for i in range(self.n_stage - 1, -1, -1):
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
post_validation_time = max(
post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
)
# post_validation_time = 0
# print(i, pv_id, post_validation_time)
for it in ["RECV_", "SEND_", ""]:
if i == 0 and it == "SEND_":
continue
if i == self.n_stage - 1 and it == "RECV_":
continue
# stage_ = i - 1 if it == "RECV_" else i
stage_ = i
local_order[stage_].append(
ScheduledNode(
type=it + "POST_VALIDATION",
chunk=0,
stage=stage_,
minibatch=0,
start_time=post_validation_time,
completion_time=post_validation_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
comm_id_counter += 1
for i in range(self.n_stage):
for _cat_, _chunk_, _micro_ in schedule[i]:
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
local_order[i].append(
ScheduledNode(
type="FBW"[_cat_],
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=i,
minibatch=_micro_,
start_time=complete_time - self.fbw_cost[_cat_],
completion_time=complete_time,
)
)
if _cat_ == 2: # no communication for W
continue
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
def communicate(send_recv, stage_):
# noinspection PyTypeChecker
local_order[stage_].append(
ScheduledNode(
type=send_recv + cat_str,
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=stage_,
minibatch=_micro_,
start_time=complete_time,
completion_time=complete_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
if _chunk_ == 1 and i > 0:
communicate("SEND_", i)
communicate("RECV_", i - 1)
if _chunk_ == 0 and i < self.n_stage - 1:
communicate("SEND_", i)
communicate("RECV_", i + 1)
comm_id_counter += 1
for rank in range(self.n_stage):
# For nodes with the same timestamp on the same stage, communication will be prioritized.
def even_breaker(x: ScheduledNode):
# Compute nodes are always delayed.
if x.type in ["F", "B", "W"]:
return comm_id_counter
# For comm nodes, order by their unique comm id
return comm_id[x]
local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x))))
# If a recv intersects with previous computation, reorder them so that the recv
# is executed before computation and hence can be overlapped.
for i in range(len(local_order[rank])):
if (
i > 0
and local_order[rank][i - 1].type in {"F", "B", "W"}
and local_order[rank][i].type.startswith("RECV")
and "POST_VALIDATION" not in local_order[rank][i].type
and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
):
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
local_order_with_rollback = [[] for _ in range(self.n_stage)]
for rank in range(self.n_stage):
rollback_comm = set()
if rank > 0:
for node in local_order[rank - 1]:
if node.type == "POST_VALIDATION":
break
if node.type == "SEND_FORWARD":
assert node.chunk == 0
rollback_comm.add(node.minibatch)
for node in local_order[rank]:
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
rollback = True
rollback_comm.remove(node.minibatch)
else:
rollback = False
local_order_with_rollback[rank].append(
ScheduledNode(
type=node.type,
chunk=node.chunk,
stage=node.stage,
minibatch=node.minibatch,
start_time=node.start_time,
completion_time=node.completion_time,
rollback=rollback,
)
)
assert len(rollback_comm) == 0
# for node in local_order_with_rollback[rank]:
# print(f"Rank {rank} Node info {node}")
# print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
# print()
return local_order_with_rollback
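# Minimal usage sketch (illustrative, not part of this file); cost and memory values are placeholders:
#   graph = PipelineGraph(n_stage=4, n_micro=8,
#                         f_cost=1000, b_cost=1000, w_cost=1000, c_cost=1,
#                         f_mem=2, b_mem=-1, w_mem=-1)
#   local_order = graph.get_v_schedule()                  # one List[ScheduledNode] per stage
#   run_time = graph.get_v_schedule(only_run_time=True)   # max bubble + 2 * n_micro * (f + b + w)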


@@ -0,0 +1,844 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
if (out is None) or (not deallocate_pipeline_outputs):
return
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
# out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
out.data.untyped_storage().resize_(0)
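# Illustration (not part of this diff): after the forward output has been sent downstream we only
# need its autograd graph, so its storage can be released while .grad_fn stays alive, e.g.
#   out = layer(x)                                            # output that has already been sent
#   deallocate_output_tensor(out, deallocate_pipeline_outputs=True)
#   out.untyped_storage().size()                              # -> 0; backward through out still works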
class ZeroBubbleVPipeScheduler(PipelineSchedule):
def __init__(
self,
stage_manager: PipelineStageManager,
schedule: List[ScheduledNode],
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
):
super().__init__(stage_manager)
# batch info
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
self.batch: Any
self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int]
self.schedules = schedule
# TODO: optim post valid
self.do_post_validation = False
# P2PMeta cache
# self.enable_metadata_cache = enable_metadata_cache
# self.send_tensor_metadata = True
# self.send_grad_metadata = True
# self.tensor_metadata_recv = None
# self.grad_metadata_recv = None
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
# init communication map
self.communication_map = {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}
# init buffer
self._free_buffers()
def _free_buffers(self):
# free local buffer
# two dim array, first dim is the model chunk, second dim is the microbatch queue
# x & y buffer for schedule b
self.input_tensors = [[], []]
self.output_tensors = [[], []]
# y & dy buffer for schedule w
self.output_tensors_dw = [[], []]
self.output_tensors_grad_dw = [[], []]
# buffer for communication
self.send_forward_buffer = [[], []]
self.recv_forward_buffer = [[], []]
self.send_backward_buffer = [[], []]
self.recv_backward_buffer = [[], []]
# y buffer for local send fwd
self.local_send_forward_buffer = []
# dy buffer for local send bwd
self.local_send_backward_buffer = []
def assert_buffer_empty(self):
# assert buffers are empty at the end
assert len(self.input_tensors[0]) == 0
assert len(self.input_tensors[1]) == 0
assert len(self.output_tensors[0]) == 0
assert len(self.output_tensors[1]) == 0
assert len(self.output_tensors_dw[0]) == 0
assert len(self.output_tensors_dw[1]) == 0
assert len(self.output_tensors_grad_dw[0]) == 0
assert len(self.output_tensors_grad_dw[1]) == 0
assert len(self.send_forward_buffer[0]) == 0
assert len(self.send_forward_buffer[1]) == 0
assert len(self.recv_forward_buffer[0]) == 0
assert len(self.recv_forward_buffer[1]) == 0
assert len(self.send_backward_buffer[0]) == 0
assert len(self.send_backward_buffer[1]) == 0
assert len(self.recv_backward_buffer[0]) == 0
assert len(self.recv_backward_buffer[1]) == 0
assert len(self.local_send_forward_buffer) == 0
assert len(self.local_send_backward_buffer) == 0
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatch
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch
assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
# if self.batch_size != self.last_batch_size:
# self.enable_metadata_cache = False
# self.send_tensor_metadata = True
# self.send_grad_metadata = True
# self.tensor_metadata_recv = None
# self.grad_metadata_recv = None
self.last_batch_size = self.batch_size
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
is_forward (bool): Whether this is the forward pass.
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
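# Worked example (not in the original file): with num_stages=4 and num_model_chunks=2 the group
# size is 8, so microbatch_id=5 -> 5 % 8 = 5 and 5 // 4 = chunk 1 on the forward pass,
# while the backward pass reverses it to num_model_chunks - 1 - 1 = chunk 0.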
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 & is_first_stage
# do nothing; chunk 0 on the first rank has no prev rank;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, []
################
# chunk = 0 & not is_first_stage
# Recv y from PREV_rank as input
#################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
else:
################
# chunk = 1 & is_last_stage
# do nothing; y is taken from local_send_forward_buffer in schedule f
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, []
################
# chunk = 1 & not is_last_stage
# recv y from NEXT_rank as input
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 & is_last_stage
# do nothing; dy is taken from local_send_backward_buffer in schedule b
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, []
################
# chunk = 0 & not is_last_stage
# Recv bwd from next stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
else:
# bwd chunk1 is left V;
################
# chunk = 1 & is_first_stage
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, []
################
# chunk = 1 & not first stage
# recv_backward recv bwd from prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
# do nothing; hold y on local_send_forward_buffer
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return []
################
# chunk = 0 && not is_last_stage
# self.comm.send_forward send y to NEXT stage
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
return send_handles
else:
################
# chunk = 1 && is_first_stage
# do nothing; the loss stays local, nothing to send forward
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 1 && not is_first_stage
# self.comm.send_forward send y to PREV stage
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_tensor, prev_rank)
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 && is_first_stage
# do nothing; chunk 0 on the first stage is the end of bwd
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 0 && not is_first_stage
# Send dx to PREV stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
return send_handles
# bwd chunk1 is left V;
else:
################
# chunk = 1 && is_last_stage
# do nothing; input_tensor_grad was already placed in local_send_backward_buffer in schedule b;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return []
################
# chunk = 1 && not is_last_stage
# Send dx to NEXT stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
return send_handles
def forward_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_b_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
) -> Optional[dict]:
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): x.
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Optional[dict]: dx.
"""
# calculate bwd b step ; only dx = w*dy;
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
return input_obj.grad
def backward_w_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Nothing to return; we only calculate dw here;
"""
# calculate bwd w step ; only dw = x*dy;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
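# Illustration (not part of this diff): for a single Linear layer y = x @ W.T the two halves of the
# backward pass are
#   dx = dy @ W      <- backward_b_step (backward_by_grad with inputs=input_obj, retain_graph=True)
#   dW = dy.T @ x    <- backward_w_step (backward_by_grad with inputs=parameters, retain_graph=False)
# so W steps can be deferred to fill pipeline bubbles without delaying the next B step.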
def schedule_f(
self,
scheduled_node,
model_chunk: torch.nn.ModuleList,
model_chunk_id: int,
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
):
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Nothing.
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = micro_batch
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_obj = self.local_send_forward_buffer.pop(0)
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Here, let input_obj.requires_grad_()
tree_map(torch.Tensor.requires_grad_, input_obj)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
detached_output_obj = output_obj.clone().detach()
# Step3: send fwd
# add output to send_fwd_buffer
if model_chunk_id == 0:
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else:
# is first stage; end of fwd; the loss stays local, nothing to append
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
def schedule_b(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else:
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# _wait_p2p(recv_bwd_handles)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
# Step3: send bwd
if model_chunk_id == 0:
# do nothing; end of bwd;
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
# save input_object_grad to send_backward_buffer
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
else:
# send to local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(input_object_grad)
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
def schedule_w(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
):
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
self.backward_w_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
output_obj=output_obj,
output_obj_grad=output_obj_grad,
)
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
assert self.forward_only
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# reset accum loss at fwd end;
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
for it in range(len(self.schedules)):
scheduled_node = self.schedules[it]
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs Zerobubble schedule, with communication between pipeline stages.
"""
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# reset accum loss at fwd end;
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
elif scheduled_node.type == "B":
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
elif scheduled_node.type == "W":
self.schedule_w(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
self.assert_buffer_empty()
return result
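# Minimal usage sketch (illustrative, not part of this diff); assumes the scheduler was built from a
# PipelineGraph V schedule as in the test added by this commit, and that model, optimizer (an
# OptimizerWrapper), dataloader and a single-argument criterion already exist on each rank:
#   result = scheduler.forward_backward_step(
#       model_chunk=model,
#       data_iter=iter(dataloader),
#       criterion=lambda output, *args: (output * output).mean(),   # placeholder loss
#       optimizer=optimizer,
#       return_loss=True,
#       return_outputs=True,
#   )
#   loss, outputs = result["loss"], result["outputs"]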


@@ -0,0 +1,769 @@
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
class MlpModel(nn.Module):
def __init__(self, in_dim, out_dim, num_layers):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
num_params = 0
num_params_trainable = 0
for p in model.parameters():
num_params += p.numel()
if p.requires_grad:
num_params_trainable += p.numel()
return num_params, num_params_trainable
# 1) Test manual v_schedule with multiple microbatches
@parameterize(
"test_config",
[
{
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
},
],
)
def run_fwd_bwd_iter_input(test_config):
# init dist
rank = dist.get_rank()
pp_size = test_config["pp_size"]
pg_mesh = ProcessGroupMesh(pp_size)
num_microbatch = test_config["num_microbatches"]
num_model_chunk = test_config["num_model_chunk"]
# stage_manager
stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
)
# schedule list
zbv_schedule = [
# stage 0
[
# microbatch 0
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
# microbatch 1
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1),
ScheduledNode(type="F", chunk=0, stage=0, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1),
ScheduledNode(type="F", chunk=1, stage=0, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1),
ScheduledNode(type="B", chunk=1, stage=0, minibatch=1),
ScheduledNode(type="W", chunk=1, stage=0, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1),
ScheduledNode(type="B", chunk=0, stage=0, minibatch=1),
ScheduledNode(type="W", chunk=0, stage=0, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
# microbatch 2
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2),
ScheduledNode(type="F", chunk=0, stage=0, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2),
ScheduledNode(type="F", chunk=1, stage=0, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2),
ScheduledNode(type="B", chunk=1, stage=0, minibatch=2),
ScheduledNode(type="W", chunk=1, stage=0, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2),
ScheduledNode(type="B", chunk=0, stage=0, minibatch=2),
ScheduledNode(type="W", chunk=0, stage=0, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
# microbatch 3
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3),
ScheduledNode(type="F", chunk=0, stage=0, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3),
ScheduledNode(type="F", chunk=1, stage=0, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3),
ScheduledNode(type="B", chunk=1, stage=0, minibatch=3),
ScheduledNode(type="W", chunk=1, stage=0, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3),
ScheduledNode(type="B", chunk=0, stage=0, minibatch=3),
ScheduledNode(type="W", chunk=0, stage=0, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
],
# stage 1
[
# microbatch 0
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
# microbatch 1
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1),
ScheduledNode(type="F", chunk=0, stage=1, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1),
ScheduledNode(type="F", chunk=1, stage=1, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1),
ScheduledNode(type="B", chunk=1, stage=1, minibatch=1),
ScheduledNode(type="W", chunk=1, stage=1, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=1),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
# microbatch 2
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2),
ScheduledNode(type="F", chunk=0, stage=1, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2),
ScheduledNode(type="F", chunk=1, stage=1, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2),
ScheduledNode(type="B", chunk=1, stage=1, minibatch=2),
ScheduledNode(type="W", chunk=1, stage=1, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=2),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
# microbatch 3
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3),
ScheduledNode(type="F", chunk=0, stage=1, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3),
ScheduledNode(type="F", chunk=1, stage=1, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3),
ScheduledNode(type="B", chunk=1, stage=1, minibatch=3),
ScheduledNode(type="W", chunk=1, stage=1, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=3),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
],
# stage 2
[
# microbatch 0
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),
# microbatch 1
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1),
ScheduledNode(type="F", chunk=0, stage=2, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1),
ScheduledNode(type="F", chunk=1, stage=2, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1),
ScheduledNode(type="B", chunk=1, stage=2, minibatch=1),
ScheduledNode(type="W", chunk=1, stage=2, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1),
ScheduledNode(type="B", chunk=0, stage=2, minibatch=1),
ScheduledNode(type="W", chunk=0, stage=2, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1),
# microbatch 2
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2),
ScheduledNode(type="F", chunk=0, stage=2, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2),
ScheduledNode(type="F", chunk=1, stage=2, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2),
ScheduledNode(type="B", chunk=1, stage=2, minibatch=2),
ScheduledNode(type="W", chunk=1, stage=2, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2),
ScheduledNode(type="B", chunk=0, stage=2, minibatch=2),
ScheduledNode(type="W", chunk=0, stage=2, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2),
# microbatch 3
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3),
ScheduledNode(type="F", chunk=0, stage=2, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3),
ScheduledNode(type="F", chunk=1, stage=2, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3),
ScheduledNode(type="B", chunk=1, stage=2, minibatch=3),
ScheduledNode(type="W", chunk=1, stage=2, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3),
ScheduledNode(type="B", chunk=0, stage=2, minibatch=3),
ScheduledNode(type="W", chunk=0, stage=2, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3),
],
# stage 3
[
# microbatch 0
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
# microbatch 1
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1),
ScheduledNode(type="F", chunk=0, stage=3, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1),
ScheduledNode(type="F", chunk=1, stage=3, minibatch=1),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1),
ScheduledNode(type="B", chunk=1, stage=3, minibatch=1),
ScheduledNode(type="W", chunk=1, stage=3, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1),
ScheduledNode(type="B", chunk=0, stage=3, minibatch=1),
ScheduledNode(type="W", chunk=0, stage=3, minibatch=1),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1),
# microbatch 2
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2),
ScheduledNode(type="F", chunk=0, stage=3, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2),
ScheduledNode(type="F", chunk=1, stage=3, minibatch=2),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2),
ScheduledNode(type="B", chunk=1, stage=3, minibatch=2),
ScheduledNode(type="W", chunk=1, stage=3, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2),
ScheduledNode(type="B", chunk=0, stage=3, minibatch=2),
ScheduledNode(type="W", chunk=0, stage=3, minibatch=2),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2),
# microbatch 3
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3),
ScheduledNode(type="F", chunk=0, stage=3, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3),
ScheduledNode(type="F", chunk=1, stage=3, minibatch=3),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3),
ScheduledNode(type="B", chunk=1, stage=3, minibatch=3),
ScheduledNode(type="W", chunk=1, stage=3, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3),
ScheduledNode(type="B", chunk=0, stage=3, minibatch=3),
ScheduledNode(type="W", chunk=0, stage=3, minibatch=3),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3),
],
]
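    # The hand-written schedule above repeats the same pattern on every stage and microbatch:
    # F(chunk 0) -> F(chunk 1) -> B(chunk 1) -> W(chunk 1) -> B(chunk 0) -> W(chunk 0),
    # each compute node wrapped in its matching RECV_*/SEND_* nodes. It is meant as a readable
    # reference schedule; the solver-generated schedules used further below typically interleave
    # work from different microbatches to reduce pipeline bubbles.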
scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: pass the whole schedule, or only this stage's local slice?
stage_manager=stage_manager,
num_model_chunks=pp_size,
num_microbatch=num_microbatch,
overlap_p2p=False,
)
# loss func
def criterion(x, *args, **kwargs):
return (x * x).mean()
# init model and input
batch_size = 4
num_layers = 8
in_dim = out_dim = 8
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
input_base = [t.clone() for t in data_iter]
model_base = deepcopy(model)
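    # Layer-to-stage assignment follows a "V" shape: rank i keeps layer i as local_chunk[0]
    # and layer (num_layers - 1 - i) as local_chunk[1], so with 8 layers on 4 ranks,
    # rank 0 holds layers 0 & 7, rank 1 holds 1 & 6, rank 2 holds 2 & 5, rank 3 holds 3 & 4.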
if rank == 0:
# layer 0 & 7 to chunk 0 on rank0
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
# init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
print(
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
)
torch.cuda.synchronize()
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
data_iter=iter(data_iter),
criterion=criterion,
optimizer=optimizer_pp,
return_loss=True,
return_outputs=True,
)
optimizer_pp.step()
##########################
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base[0])
loss_base = criterion(output_base)
loss_base.backward()
optimizer_base.step()
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
##########################
# assert weight
##########################
if rank == 0:
# layer 0
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
# layer 7
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
if rank == 1:
# layer 1
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
# layer 6
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
if rank == 2:
# layer 2
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
# layer 5
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
if rank == 3:
# layer 3
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
# layer 4
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
# 2) add an optimizer, building on test 1)
@parameterize(
"test_config",
[
{
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
},
{
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 8,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
},
],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
# init dist
rank = dist.get_rank()
pp_size = test_config["pp_size"]
pg_mesh = ProcessGroupMesh(pp_size)
num_microbatch = test_config["num_microbatches"]
num_model_chunk = test_config["num_model_chunk"]
# stage_manager
stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
)
h, a, s = 4096, 32, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
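    # Relative memory weights for the zero-bubble solver; only the ratios matter for this test.
    # A forward pass (F) is assumed to allocate roughly 34*h + 5*a*s units of activations per
    # layer, the weight-grad pass (W) frees 32*h of them, and the input-grad pass (B) frees the
    # rest, so one full F + B + W cycle is memory-neutral by construction:
    assert mem_f + mem_b + mem_w == 0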
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatch,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
# max_mem=mem_f * (p * 2 + m_offset),
)
zbv_schedule = graph.get_v_schedule()
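    # get_v_schedule() returns one list of ScheduledNode per pipeline stage, in the same spirit
    # as the hand-written schedule in the test above, but with the node order chosen by the
    # solver to minimize bubbles under the cost/memory weights supplied to PipelineGraph.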
scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: pass the whole schedule, or only this stage's local slice?
stage_manager=stage_manager,
num_model_chunks=num_model_chunk,
num_microbatch=num_microbatch,
overlap_p2p=False,
)
# init loss func
def criterion(x, *args, **kwargs):
return (x * x).mean()
# init model and input
batch_size = test_config["batch_size"]
num_layers = 8
    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layers cannot be evenly split into {num_model_chunk} chunks"
in_dim = out_dim = 4096
before_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
input_base = [t.clone() for t in data_iter]
model_base = deepcopy(model)
if rank == 0:
# layer 0 & 7 to chunk 0 on rank0
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
# init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
after_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
torch.cuda.synchronize()
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
data_iter=iter(data_iter),
criterion=criterion,
optimizer=optimizer_pp,
return_loss=True,
return_outputs=True,
)
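    # result is a dict: with return_loss=True / return_outputs=True it carries the loss under
    # "loss" and the model output under "outputs" on the stage that owns the last chunk
    # (rank 0 here); both are compared against the single-process baseline below.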
optimizer_pp.step()
after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3
# assert memory
if rank != 0:
        # theoretical upper bound for each non-rank-0 stage:
        #   weight grads: hid_dim * hid_dim * 4 bytes (fp32) * 2 (layers per stage) / 1024**3
        #   output:       hid_dim * hid_dim * 4 bytes (fp32)                        / 1024**3
        #   optim state:  hid_dim * hid_dim * 4 bytes (fp32) * 2 (layers per stage) / 1024**3
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
else:
        # rank 0 additionally keeps the returned output, so its budget is larger;
print(
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
)
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
(in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
)
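    # With in_dim == 4096 and fp32 tensors, the budget asserted above works out to
    # 4096 * 4096 * 4 * 5 / 1024**3 = 0.3125 GB (2x weight grads + 2x momentum buffers + 1x output
    # activation), plus 8 * 4096 * 4096 * 4 / 1024**3 = 0.5 GB for the batch output kept on rank 0.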
##########################
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base[0])
loss_base = criterion(output_base)
loss_base.backward()
optimizer_base.step()
##########################
# assert loss & output
##########################
    # only chunk 1 on stage 0 holds the loss and the output
if rank == 0:
assert_close(result["loss"], loss_base)
assert_close(result["outputs"], output_base)
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
##########################
# assert weight
##########################
if rank == 0:
# layer 0
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
# layer 7
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
if rank == 1:
# layer 1
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
# layer 6
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
if rank == 2:
# layer 2
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
# layer 5
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
if rank == 3:
# layer 3
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
# layer 4
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
##########################
# assert optim state
##########################
optim_base_state = optimizer_base.state_dict()["state"]
optim_pp_state = optimizer_pp.state_dict()["state"]
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
# if rank == 0:
# print(f"optim_base_state {optim_base_state}")
# assert param group
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp:
if key_base != "params":
assert val_base == val_pp
else:
                # BUG: the param index lists are not directly comparable;
                # param_base lists all 8 params: [0, 1, 2, 3, 4, 5, 6, 7],
                # while params pp only lists the 2 local params: [0, 1],
                # so only the matching prefix is compared here.
assert val_base[:2] == val_pp
# assert state
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
# TODO: 4) support Hybrid plugin, building on 3)
def run_with_hybridplugin(test_config):
pass
# TODO: 5) support MoEHybrid plugin, building on 3)
@parameterize(
"test_config",
[
{
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
},
],
)
def run_with_moehybridplugin(test_config):
model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
test_config["initial_scale"] = 2**16
model_list = [
"transformers_bert",
]
# TODO:6) support booster & Hybrid base 4)
# TODO:7) support booster & MoEHybrid base 4)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_iter_input()
run_fwd_bwd_vschedule_with_optim()
# run_with_moehybridplugin()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
spawn(
run_dist,
nprocs=4,
)
if __name__ == "__main__":
test_pp()