mirror of https://github.com/hpcaitech/ColossalAI
[feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
parent
75c963686f
commit
ee9baedadf
|
@ -1,11 +1,12 @@
|
|||
from .p2p import PipelineP2PCommunication
|
||||
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
|
||||
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
|
||||
from .stage_manager import PipelineStageManager
|
||||
|
||||
__all__ = [
|
||||
"PipelineSchedule",
|
||||
"OneForwardOneBackwardSchedule",
|
||||
"InterleavedSchedule",
|
||||
"ZeroBubbleVPipeScheduler",
|
||||
"PipelineP2PCommunication",
|
||||
"PipelineStageManager",
|
||||
]
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from .base import PipelineSchedule
|
||||
from .interleaved_pp import InterleavedSchedule
|
||||
from .one_f_one_b import OneForwardOneBackwardSchedule
|
||||
from .zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
|
||||
__all__ = [
|
||||
"PipelineSchedule",
|
||||
"OneForwardOneBackwardSchedule",
|
||||
"InterleavedSchedule",
|
||||
"ZeroBubbleVPipeScheduler",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,468 @@
|
|||
# Refer from Zero Bubble Pipeline Parallelism.
|
||||
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
|
||||
# Paper: https://arxiv.org/abs/2401.10241
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ScheduledNode:
|
||||
type: str
|
||||
chunk: int
|
||||
stage: int
|
||||
minibatch: int
|
||||
start_time: int
|
||||
completion_time: int
|
||||
rollback: bool = False
|
||||
|
||||
|
||||
class PipelineGraph(object):
|
||||
"""PipelineGraph"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_stage,
|
||||
n_micro,
|
||||
f_cost,
|
||||
b_cost,
|
||||
w_cost,
|
||||
c_cost,
|
||||
f_mem,
|
||||
b_mem,
|
||||
w_mem,
|
||||
max_mem=None,
|
||||
):
|
||||
self.n_node = 6 * n_stage * n_micro
|
||||
self.n_stage = n_stage
|
||||
self.n_micro = n_micro
|
||||
self.f_cost = f_cost
|
||||
self.b_cost = b_cost
|
||||
self.w_cost = w_cost
|
||||
self.c_cost = c_cost
|
||||
self.f_mem = f_mem
|
||||
self.b_mem = b_mem
|
||||
self.w_mem = w_mem
|
||||
self.fbw_cost = [f_cost, b_cost, w_cost]
|
||||
self.fbw_mem = [f_mem, b_mem, w_mem]
|
||||
self.max_mem = max_mem or f_mem * self.n_stage * 2
|
||||
|
||||
def get_id(self, cat, chunk, stage, micro):
|
||||
return (
|
||||
cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
|
||||
)
|
||||
|
||||
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
|
||||
count = []
|
||||
for i in range(self.n_stage):
|
||||
count.append([0] * 6)
|
||||
|
||||
end_time = [-1] * self.n_node
|
||||
cur_time = [0] * self.n_stage
|
||||
mem = [0] * self.n_stage
|
||||
stage_bubble = [0] * self.n_stage
|
||||
pending_w = [deque() for _ in range(self.n_stage)]
|
||||
schedule = [[] for _ in range(self.n_stage)]
|
||||
stage_str = [" " * i for i in range(self.n_stage)]
|
||||
|
||||
if approved_bubble is None:
|
||||
approved_bubble = [-1] * self.n_stage
|
||||
max_approved_bubble = max(approved_bubble)
|
||||
|
||||
def get_max_stage_bubble(stage=-1):
|
||||
max_stage_bubble = 0
|
||||
for bb in stage_bubble:
|
||||
max_stage_bubble = max(max_stage_bubble, bb)
|
||||
if stage >= 0:
|
||||
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
|
||||
return max_stage_bubble
|
||||
|
||||
def put_w(stage):
|
||||
assert len(pending_w[stage]) > 0
|
||||
_, chunk_, _ = pending_w[stage].popleft()
|
||||
put(2, chunk_, stage)
|
||||
|
||||
def put(cat, chunk, stage, assert_cnt=True):
|
||||
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
|
||||
_cnt = count[stage][cat * 2 + chunk]
|
||||
# assert _cnt < self.n_micro
|
||||
if _cnt >= self.n_micro:
|
||||
if not assert_cnt:
|
||||
stage_str[stage] += " "
|
||||
cur_time[stage] = _tmp # TODO
|
||||
return
|
||||
assert False
|
||||
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
|
||||
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
|
||||
if cat > 0 or chunk > 0:
|
||||
last_id = cat * 2 + chunk - 1
|
||||
if cat < 2:
|
||||
# if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0:
|
||||
# print(cat, chunk, stage, _cnt)
|
||||
# self.print_details(end_time)
|
||||
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
|
||||
else:
|
||||
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
|
||||
if chunk == 1 and cat < 2:
|
||||
if stage < self.n_stage - 1:
|
||||
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
|
||||
assert end_time[_fa_id] >= 0
|
||||
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
||||
if chunk == 0 and cat < 2:
|
||||
if stage > 0:
|
||||
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
|
||||
# if end_time[_fa_id] < 0:
|
||||
# print(cat, chunk, stage, _cnt)
|
||||
# self.print_details(end_time)
|
||||
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
|
||||
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
||||
_id = self.get_id(cat, chunk, stage, _cnt)
|
||||
if count[stage][0] > 0:
|
||||
stage_bubble[stage] += _tmp - _no_bubble
|
||||
end_time[_id] = _tmp
|
||||
cur_time[stage] = _tmp
|
||||
mem[stage] += self.fbw_mem[cat]
|
||||
# noinspection PyTypeChecker
|
||||
schedule[stage].append((cat, chunk, _cnt))
|
||||
if cat == 1:
|
||||
pending_w[stage].append((2, chunk, _cnt))
|
||||
count[stage][cat * 2 + chunk] += 1
|
||||
|
||||
# for _ in range(2 * self.n_stage):
|
||||
# for i in range(self.n_stage):
|
||||
# if count[i][1] >= count[i][0]:
|
||||
# put(0, 0, i, assert_cnt=False)
|
||||
# continue
|
||||
# if i == self.n_stage - 1:
|
||||
# put(0, 1, i, assert_cnt=False)
|
||||
# continue
|
||||
# fa_id = self.get_id(0, 1, i + 1, count[i][1])
|
||||
# if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO
|
||||
# put(0, 1, i, assert_cnt=False)
|
||||
# else:
|
||||
# put(0, 0, i, assert_cnt=False)
|
||||
|
||||
for i in range(self.n_stage):
|
||||
put(0, 0, i)
|
||||
for i in range(self.n_stage - 1, -1, -1):
|
||||
if i == self.n_stage - 1:
|
||||
put(0, 1, i)
|
||||
continue
|
||||
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
|
||||
while (
|
||||
mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
|
||||
and cur_time[i] + self.fbw_cost[0] <= tmp
|
||||
and count[i][0] < self.n_micro
|
||||
):
|
||||
for j in range(i + 1):
|
||||
put(0, 0, j)
|
||||
put(0, 1, i)
|
||||
iter_chunk_ = 0
|
||||
end_tmp = 0
|
||||
for i in range(self.n_stage):
|
||||
if i == 0:
|
||||
end_tmp = cur_time[0] + self.fbw_cost[1]
|
||||
continue
|
||||
tmp = end_tmp + self.c_cost
|
||||
while (
|
||||
count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
|
||||
or count[i][1] <= count[i - 1][1] < self.n_micro
|
||||
):
|
||||
for j in range(self.n_stage - 1, i - 1, -1):
|
||||
if count[j][iter_chunk_] < self.n_micro:
|
||||
put(0, iter_chunk_, j)
|
||||
iter_chunk_ = 1 - iter_chunk_
|
||||
# while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp:
|
||||
# if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]:
|
||||
# break
|
||||
# for j in range(self.n_stage - 1, i - 1, -1):
|
||||
# if count[j][iter_chunk_] < self.n_micro:
|
||||
# put(0, iter_chunk_, j)
|
||||
# iter_chunk_ = 1 - iter_chunk_
|
||||
# end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1]
|
||||
|
||||
# init_bubble = get_max_stage_bubble()
|
||||
# print(stage_bubble)
|
||||
for _ in range(2 * self.n_micro):
|
||||
# check mem before putting b
|
||||
for i in range(self.n_stage):
|
||||
while mem[i] + self.fbw_mem[1] > self.max_mem:
|
||||
assert len(pending_w[i]) > 0
|
||||
put_w(i)
|
||||
b0_ranks, b1_ranks = [], []
|
||||
for i in range(self.n_stage):
|
||||
if count[i][3] >= count[i][2]:
|
||||
b0_ranks.append(i)
|
||||
elif i == self.n_stage - 1:
|
||||
b1_ranks.append(i)
|
||||
else:
|
||||
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
||||
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
|
||||
b1_ranks.append(i)
|
||||
else:
|
||||
b0_ranks.append(i)
|
||||
b_ranks = []
|
||||
# put b1
|
||||
for i in reversed(b1_ranks):
|
||||
b_ranks.append((i, 1))
|
||||
# put b0
|
||||
for i in b0_ranks:
|
||||
b_ranks.append((i, 0))
|
||||
for i, _chunk_ in b_ranks:
|
||||
fa_id = -1
|
||||
if _chunk_ == 1 and i < self.n_stage - 1:
|
||||
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
||||
if _chunk_ == 0 and i > 0:
|
||||
fa_id = self.get_id(1, 0, i - 1, count[i][2])
|
||||
while (
|
||||
len(pending_w[i]) > 0
|
||||
and fa_id >= 0
|
||||
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
|
||||
):
|
||||
# fill the bubble
|
||||
put_w(i)
|
||||
if (
|
||||
len(pending_w[i]) > 0
|
||||
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
|
||||
):
|
||||
if _chunk_ == 1:
|
||||
put_w(i)
|
||||
elif fill_b:
|
||||
put_w(i)
|
||||
put(1, _chunk_, i)
|
||||
|
||||
# put f
|
||||
for i in range(self.n_stage):
|
||||
if count[i][1] >= self.n_micro:
|
||||
continue
|
||||
put_item = None
|
||||
if count[i][1] >= count[i][0]:
|
||||
put_item = 0
|
||||
elif i == self.n_stage - 1:
|
||||
put_item = 1
|
||||
else:
|
||||
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
|
||||
put_item = 1
|
||||
elif count[i][0] < self.n_micro:
|
||||
if i == 0:
|
||||
put_item = 0
|
||||
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
|
||||
put_item = 0
|
||||
if put_item is None:
|
||||
continue
|
||||
# check mem before putting f
|
||||
while mem[i] + self.fbw_mem[0] > self.max_mem:
|
||||
assert len(pending_w[i]) > 0
|
||||
put_w(i)
|
||||
fa_id = -1
|
||||
if put_item == 0 and i > 0:
|
||||
fa_id = self.get_id(0, 0, i - 1, count[i][0])
|
||||
if put_item == 1 and i < self.n_stage - 1:
|
||||
fa_id = self.get_id(0, 1, i + 1, count[i][1])
|
||||
while (
|
||||
len(pending_w[i]) > 0
|
||||
and fa_id >= 0
|
||||
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
|
||||
):
|
||||
# fill the bubble
|
||||
put_w(i)
|
||||
if (
|
||||
len(pending_w[i]) > 0
|
||||
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
|
||||
):
|
||||
if fill_f:
|
||||
put_w(i)
|
||||
put(0, put_item, i)
|
||||
|
||||
for i in range(self.n_stage):
|
||||
while len(pending_w[i]) > 0:
|
||||
put_w(i)
|
||||
|
||||
# for i in range(self.n_stage):
|
||||
# print(stage_str[i])
|
||||
|
||||
max_bubble = get_max_stage_bubble()
|
||||
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
||||
max_bubble / expected_time
|
||||
# print("%6.4f" % bubble_rate, "->", stage_bubble)
|
||||
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
|
||||
_schedule, _end_time, _max_bubble = self.try_v_schedule(
|
||||
fill_f=fill_f,
|
||||
fill_b=fill_b,
|
||||
approved_bubble=stage_bubble,
|
||||
)
|
||||
if _max_bubble < max_bubble:
|
||||
return _schedule, _end_time, _max_bubble
|
||||
# print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \
|
||||
# (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble)
|
||||
return schedule, end_time, max_bubble
|
||||
|
||||
def print_details(self, end_time, print_scaling=1):
|
||||
for stage in range(self.n_stage):
|
||||
stage_str = ["."] * int(max(end_time) / print_scaling)
|
||||
for _cat in range(3):
|
||||
for _chunk in range(2):
|
||||
for _micro in range(self.n_micro):
|
||||
_id = self.get_id(_cat, _chunk, stage, _micro)
|
||||
if end_time[_id] < 0:
|
||||
continue
|
||||
end = int(end_time[_id] / print_scaling)
|
||||
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
|
||||
for j in range(start, end):
|
||||
if j == start or j == end - 1:
|
||||
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
|
||||
elif j == start + 1:
|
||||
if _micro >= 10:
|
||||
stage_str[j] = str(_micro // 10)
|
||||
else:
|
||||
stage_str[j] = str(_micro)
|
||||
elif j == start + 2 and _micro >= 10:
|
||||
stage_str[j] = str(_micro % 10)
|
||||
else:
|
||||
stage_str[j] = "-"
|
||||
_str = ""
|
||||
for _c in stage_str:
|
||||
_str += _c
|
||||
print(_str)
|
||||
|
||||
def get_v_schedule(self, only_run_time=False):
|
||||
schedule, end_time, max_bubble = None, None, None
|
||||
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
||||
for fill_b in [True, False]:
|
||||
for fill_f in [True, False]:
|
||||
_schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
|
||||
# print("")
|
||||
if max_bubble is None or _max_bubble < max_bubble:
|
||||
max_bubble = _max_bubble
|
||||
schedule = _schedule
|
||||
end_time = _end_time
|
||||
if only_run_time:
|
||||
return max_bubble + expected_time
|
||||
# self.print_details(end_time, print_scaling=1)
|
||||
max_bubble / (expected_time + max_bubble)
|
||||
# print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \
|
||||
# (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate))
|
||||
local_order = [[] for _ in range(self.n_stage)]
|
||||
comm_id = {}
|
||||
comm_id_counter = 0
|
||||
post_validation_time = 0
|
||||
for i in range(self.n_stage - 1, -1, -1):
|
||||
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
|
||||
post_validation_time = max(
|
||||
post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
|
||||
)
|
||||
# post_validation_time = 0
|
||||
# print(i, pv_id, post_validation_time)
|
||||
for it in ["RECV_", "SEND_", ""]:
|
||||
if i == 0 and it == "SEND_":
|
||||
continue
|
||||
if i == self.n_stage - 1 and it == "RECV_":
|
||||
continue
|
||||
# stage_ = i - 1 if it == "RECV_" else i
|
||||
stage_ = i
|
||||
local_order[stage_].append(
|
||||
ScheduledNode(
|
||||
type=it + "POST_VALIDATION",
|
||||
chunk=0,
|
||||
stage=stage_,
|
||||
minibatch=0,
|
||||
start_time=post_validation_time,
|
||||
completion_time=post_validation_time,
|
||||
)
|
||||
)
|
||||
comm_id[local_order[stage_][-1]] = comm_id_counter
|
||||
comm_id_counter += 1
|
||||
for i in range(self.n_stage):
|
||||
for _cat_, _chunk_, _micro_ in schedule[i]:
|
||||
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
|
||||
local_order[i].append(
|
||||
ScheduledNode(
|
||||
type="FBW"[_cat_],
|
||||
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
||||
stage=i,
|
||||
minibatch=_micro_,
|
||||
start_time=complete_time - self.fbw_cost[_cat_],
|
||||
completion_time=complete_time,
|
||||
)
|
||||
)
|
||||
if _cat_ == 2: # no communication for W
|
||||
continue
|
||||
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
|
||||
|
||||
def communicate(send_recv, stage_):
|
||||
# noinspection PyTypeChecker
|
||||
local_order[stage_].append(
|
||||
ScheduledNode(
|
||||
type=send_recv + cat_str,
|
||||
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
||||
stage=stage_,
|
||||
minibatch=_micro_,
|
||||
start_time=complete_time,
|
||||
completion_time=complete_time,
|
||||
)
|
||||
)
|
||||
comm_id[local_order[stage_][-1]] = comm_id_counter
|
||||
|
||||
if _chunk_ == 1 and i > 0:
|
||||
communicate("SEND_", i)
|
||||
communicate("RECV_", i - 1)
|
||||
if _chunk_ == 0 and i < self.n_stage - 1:
|
||||
communicate("SEND_", i)
|
||||
communicate("RECV_", i + 1)
|
||||
comm_id_counter += 1
|
||||
for rank in range(self.n_stage):
|
||||
# For nodes with the same timestamp on the same stage, communication will be prioritized.
|
||||
def even_breaker(x: ScheduledNode):
|
||||
# Compute nodes are always delayed.
|
||||
if x.type in ["F", "B", "W"]:
|
||||
return comm_id_counter
|
||||
# For comm nodes, order by their unique comm id
|
||||
return comm_id[x]
|
||||
|
||||
local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x))))
|
||||
# If a recv with intersects with previous computation, reorder them so that recv
|
||||
# is executed before computation and hence can be overlapped.
|
||||
for i in range(len(local_order[rank])):
|
||||
if (
|
||||
i > 0
|
||||
and local_order[rank][i - 1].type in {"F", "B", "W"}
|
||||
and local_order[rank][i].type.startswith("RECV")
|
||||
and "POST_VALIDATION" not in local_order[rank][i].type
|
||||
and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
|
||||
):
|
||||
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
|
||||
|
||||
local_order_with_rollback = [[] for _ in range(self.n_stage)]
|
||||
for rank in range(self.n_stage):
|
||||
rollback_comm = set()
|
||||
if rank > 0:
|
||||
for node in local_order[rank - 1]:
|
||||
if node.type == "POST_VALIDATION":
|
||||
break
|
||||
if node.type == "SEND_FORWARD":
|
||||
assert node.chunk == 0
|
||||
rollback_comm.add(node.minibatch)
|
||||
for node in local_order[rank]:
|
||||
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
|
||||
rollback = True
|
||||
rollback_comm.remove(node.minibatch)
|
||||
else:
|
||||
rollback = False
|
||||
local_order_with_rollback[rank].append(
|
||||
ScheduledNode(
|
||||
type=node.type,
|
||||
chunk=node.chunk,
|
||||
stage=node.stage,
|
||||
minibatch=node.minibatch,
|
||||
start_time=node.start_time,
|
||||
completion_time=node.completion_time,
|
||||
rollback=rollback,
|
||||
)
|
||||
)
|
||||
assert len(rollback_comm) == 0
|
||||
for node in local_order_with_rollback[rank]:
|
||||
print(f"Rank {rank} Node info {node}")
|
||||
print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
|
||||
print()
|
||||
|
||||
return local_order_with_rollback
|
|
@ -0,0 +1,615 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device
|
||||
from .base import PipelineSchedule
|
||||
|
||||
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
||||
|
||||
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
def __init__(
|
||||
self,
|
||||
stage_manager: PipelineStageManager,
|
||||
schedule: List[ScheduledNode],
|
||||
num_model_chunks: int,
|
||||
num_microbatch: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
):
|
||||
super().__init__(stage_manager)
|
||||
self.num_microbatch = num_microbatch
|
||||
self.collect_non_loss_data = None
|
||||
self.forward_only = None
|
||||
|
||||
self.schedules = schedule
|
||||
self.it = 0 # curr iteration
|
||||
self.do_post_validation = False
|
||||
self.is_first_run = True
|
||||
self.optimizer = None
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
||||
# P2PMeta cache
|
||||
# self.enable_metadata_cache = enable_metadata_cache
|
||||
# self.send_tensor_metadata = True
|
||||
# self.send_grad_metadata = True
|
||||
# self.tensor_metadata_recv = None
|
||||
# self.grad_metadata_recv = None
|
||||
|
||||
# P2P communication
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
|
||||
# init buffer
|
||||
self._free_buffers()
|
||||
|
||||
def _free_buffers(self):
|
||||
# free local buffer
|
||||
# two dim array, first dim is the model chunk, second dim is the microbatch queue
|
||||
self.input_tensors = [[], []]
|
||||
self.output_tensors = [[], []]
|
||||
self.send_forward_buffer = [[], []]
|
||||
self.recv_forward_buffer = [[], []]
|
||||
self.send_backward_buffer = [[], []]
|
||||
self.recv_backward_buffer = [[], []]
|
||||
self.forward_data_store = []
|
||||
self.local_send_forward_buffer = []
|
||||
self.local_send_backward_buffer = []
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
Args:
|
||||
data_iter (Iterable): Data iterator.
|
||||
device (Optional[torch.device], optional): Target device. Defaults to None.
|
||||
"""
|
||||
batch = next(data_iter)
|
||||
if device is not None:
|
||||
batch = tree_map(partial(to_device, device=device), batch)
|
||||
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
|
||||
if self.microbatch_size is None:
|
||||
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatch
|
||||
if self.num_microbatch is None:
|
||||
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
|
||||
self.num_microbatch = self.batch_size // self.microbatch_size
|
||||
|
||||
if not self.forward_only:
|
||||
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
|
||||
assert self.batch_size == self.microbatch_size * self.num_microbatch
|
||||
|
||||
assert (
|
||||
self.num_microbatch % self.stage_manager.num_stages == 0
|
||||
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
|
||||
|
||||
if self.forward_only:
|
||||
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
|
||||
# NOTE: disable metadata cache when batch size changes (not valid anymore)
|
||||
# if self.batch_size != self.last_batch_size:
|
||||
# self.enable_metadata_cache = False
|
||||
# self.send_tensor_metadata = True
|
||||
# self.send_grad_metadata = True
|
||||
# self.tensor_metadata_recv = None
|
||||
# self.grad_metadata_recv = None
|
||||
|
||||
self.last_batch_size = self.batch_size
|
||||
|
||||
def load_micro_batch(self, model_chunk_id: int) -> Any:
|
||||
"""Load a micro batch from the current batch.
|
||||
|
||||
Args:
|
||||
microbatch_id (int): the current model chunk idx.
|
||||
|
||||
Returns:
|
||||
Any: Micro batch.
|
||||
"""
|
||||
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
|
||||
"""Helper method to get the model chunk ID given the iteration number.
|
||||
|
||||
Args:
|
||||
microbatch_id (int): the current microbatch idx
|
||||
forward (bool): if is the forward process
|
||||
|
||||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
assert (
|
||||
microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not is_forward:
|
||||
# Reverse order
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
prev_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 & is_first_stage
|
||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||
#################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return None, []
|
||||
|
||||
################
|
||||
# chunk = 0 & not is_first_stage
|
||||
# Recv y from PREV_rank as input
|
||||
#################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 & is_last_stage
|
||||
# get y from local_send_forward_buffer as input
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
input_tensor = self.local_send_forward_buffer.pop(0)
|
||||
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor, []
|
||||
|
||||
################
|
||||
# chunk = 1 & not is_last_stage
|
||||
# recv y from NEXT_rank as input
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
|
||||
|
||||
# metadata_recv=self.tensor_metadata_recv
|
||||
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
# bwd chunk0 is right V;
|
||||
################
|
||||
# chunk = 0 & is_last_stage
|
||||
# get dy from local recv_bwd_buffer
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, []
|
||||
|
||||
################
|
||||
# chunk = 0 & not is_last_stage
|
||||
# Recv bwd from next stage;
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
else:
|
||||
# bwd chunk1 is left V;
|
||||
################
|
||||
# chunk = 1 & is_first_stage
|
||||
# do nothing; get loss from local
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return None, []
|
||||
|
||||
################
|
||||
# chunk = 1 & not is_first_stage
|
||||
# self.comm.recv_backward recv bwd from prev stage;
|
||||
################
|
||||
else:
|
||||
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
|
||||
|
||||
# metadata_recv=self.grad_metadata_recv
|
||||
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
################
|
||||
# chunk = 0 && is_last_stage
|
||||
# hold y on local_send_forward_buffer
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_forward_buffer.append(output_tensor)
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 && not is_last_stage
|
||||
# self.comm.send_forward send y to NEXT stage
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
|
||||
# send_metadata=self.send_tensor_metadata
|
||||
# self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
else:
|
||||
################
|
||||
# chunk = 1 && is_first_stage
|
||||
# do nothing; cause you are the last chunk on last stage;
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 && not is_first_stage
|
||||
# self.comm.send_forward send y to PREV stage
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
send_handles = self.comm.send_forward(output_tensor, prev_rank)
|
||||
# send_metadata=self.send_tensor_metadata
|
||||
# self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For ZBV.
|
||||
|
||||
Args:
|
||||
model_chunk_id (int): The current model chunk idx.
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if model_chunk_id == 0:
|
||||
# bwd chunk0 is right V;
|
||||
################
|
||||
# chunk = 0 && is_first_stage
|
||||
# do nothing; cause u are the first chunk in first stage; bwd end
|
||||
# send input_tensor_grad to local buffer;
|
||||
################
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 0 && not is_first_stage
|
||||
# Send dx to PREV stage;
|
||||
################
|
||||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
|
||||
# send_metadata=self.send_grad_metadata
|
||||
return send_handles
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
else:
|
||||
################
|
||||
# chunk = 1 && is_last_stage
|
||||
# hold dy to local_send_bwd_buffer;
|
||||
################
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_backward_buffer.append(input_tensor_grad)
|
||||
return []
|
||||
|
||||
################
|
||||
# chunk = 1 && not is_last_stage
|
||||
# Send dx to NEXT stage;
|
||||
################
|
||||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
|
||||
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
|
||||
# send_metadata=self.send_grad_metadata
|
||||
return send_handles
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model (ModuleList or Module): Model Chunk to be run
|
||||
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
criterion (Callable): Criterion to calculate loss.
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
"""
|
||||
# Load input ids, attention mask and labels
|
||||
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
|
||||
# for the first stage, input_obj is None
|
||||
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
|
||||
# Only attention_mask from micro_batch is used
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
output_obj = model_chunk[model_chunk_id](input_obj)
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
loss = criterion(output_obj) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
outputs.append(tree_map(detach, output_obj))
|
||||
return loss
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def backward_b_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
) -> Optional[dict]:
|
||||
"""Backward one step of the pipeline
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
|
||||
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
|
||||
"""
|
||||
# calculate bwd b step ; only dx = w*dy;
|
||||
|
||||
# Retain the grad on the input_obj.
|
||||
tree_map(retain_grad, input_obj)
|
||||
|
||||
if model_chunk_id == 0:
|
||||
# bwd step
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
|
||||
)
|
||||
else:
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# loss backward; output_obj is loss
|
||||
torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
|
||||
else:
|
||||
# commom bwd step
|
||||
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
|
||||
# BUG:output_obj_grad is None
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
|
||||
)
|
||||
|
||||
return input_obj.grad
|
||||
|
||||
def backward_w_step(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
):
|
||||
# calculate bwd w step ; only dw = x*dy;
|
||||
if model_chunk_id == 0:
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
|
||||
)
|
||||
|
||||
else:
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters()))
|
||||
|
||||
else:
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj,
|
||||
grad_tensors=output_obj_grad,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
)
|
||||
|
||||
def schedule_f(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: torch.nn.ModuleList,
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
):
|
||||
# Step1: recv fwd
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# first layer
|
||||
input_obj = input_obj
|
||||
else:
|
||||
# other layer
|
||||
input_obj, wait_handles = self.recv_forward(model_chunk_id)
|
||||
# print(f"recv input_obj {input_obj}")
|
||||
_wait_p2p(wait_handles)
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
|
||||
|
||||
# add input and output object for backward
|
||||
self.input_tensors[model_chunk_id].append(input_obj)
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
|
||||
# Step3: send fwd
|
||||
send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
|
||||
|
||||
def schedule_b(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
# input_obj: Optional[dict],
|
||||
# output_obj: Union[dict, torch.Tensor],
|
||||
# output_obj_grad: Optional[dict],
|
||||
):
|
||||
# Step1: recv bwd
|
||||
# not first stage and chunk 1
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad, recv_bwd_handles = None, []
|
||||
# print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
else:
|
||||
output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
|
||||
# print(f"recv output_tensor_grad {output_tensor_grad}")
|
||||
|
||||
# get input and output object from buffer
|
||||
input_obj = self.input_tensors[model_chunk_id].pop()
|
||||
output_obj = self.output_tensors[model_chunk_id].pop()
|
||||
|
||||
_wait_p2p(recv_bwd_handles)
|
||||
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
|
||||
# Step2: bwd step
|
||||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
input_obj=input_obj,
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
)
|
||||
print(f"input_object_grad {input_object_grad}")
|
||||
|
||||
# Step3: send bwd
|
||||
send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
|
||||
|
||||
def schedule_w(
|
||||
self,
|
||||
scheduled_node,
|
||||
non_w_pending,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
):
|
||||
self.backward_w_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
input_obj=input_obj,
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_obj_grad,
|
||||
)
|
||||
|
||||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
):
|
||||
it = self.it
|
||||
# while we still have schedules_node in self.schedules
|
||||
while it < len(self.schedules):
|
||||
scheduled_node = self.schedules[it]
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
if scheduled_node.type == "RECV_FORWARD":
|
||||
self.recv_forward()
|
||||
elif scheduled_node.type == "RECV_BACKWARD":
|
||||
self.recv_backward()
|
||||
elif scheduled_node.type == "SEND_FORWARD":
|
||||
self.send_forward()
|
||||
elif scheduled_node.type == "SEND_BACKWARD":
|
||||
self.send_backward()
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f()
|
||||
elif scheduled_node.type == "B":
|
||||
self.schedule_b()
|
||||
elif scheduled_node.type == "W":
|
||||
self.schedule_w()
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,341 @@
|
|||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, num_layers):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
num_params = 0
|
||||
num_params_trainable = 0
|
||||
for p in model.parameters():
|
||||
num_params += p.numel()
|
||||
if p.requires_grad:
|
||||
num_params_trainable += p.numel()
|
||||
return num_params, num_params_trainable
|
||||
|
||||
|
||||
def test_zerobubble_pipeline_base(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
):
|
||||
# init dist
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
pg_mesh = ProcessGroupMesh(world_size)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=world_size)
|
||||
|
||||
scheduler = ZeroBubbleVPipeScheduler(
|
||||
schedule=[],
|
||||
stage_manager=stage_manager,
|
||||
num_model_chunks=world_size,
|
||||
num_microbatch=1,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
||||
# init model and input
|
||||
num_layers = 8
|
||||
in_dim = out_dim = 2048
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
|
||||
|
||||
input0.clone()
|
||||
deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
# layer 0 & 7 to chunk 0 on rank0
|
||||
chunk_0 = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 0 or idx == 7:
|
||||
chunk_0.append(sub_model)
|
||||
elif rank == 1:
|
||||
# layer 1 & 6 to chunk 1 on rank1
|
||||
chunk_1 = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 1 or idx == 6:
|
||||
chunk_1.append(sub_model)
|
||||
elif rank == 2:
|
||||
# layer 2 & 5 to chunk 2 on rank2
|
||||
chunk_2 = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 2 or idx == 5:
|
||||
chunk_2.append(sub_model)
|
||||
else:
|
||||
# layer 3 & 4 to chunk 3 on rank3
|
||||
chunk_3 = torch.nn.Sequential().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 3 or idx == 4:
|
||||
chunk_3.append(sub_model)
|
||||
print(
|
||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
def criterion(x, *args, **kwargs):
|
||||
return (x * x).mean()
|
||||
|
||||
##########################
|
||||
# Step1: fwd
|
||||
##########################
|
||||
######
|
||||
# fwd 1->4
|
||||
######
|
||||
# chunk 0 id 0 (layer 0) fwd
|
||||
if rank == 0:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_0,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=input0,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 0 id 0 (layer 0)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
# chunk 1 id 0 (layer 1) fwd
|
||||
if rank == 1:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_1,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 1 id 0 (layer 1)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
# chunk 2 id 0 (layer 2) fwd
|
||||
if rank == 2:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_2,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 2 id 0 (layer 2)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
# chunk 3 id 0 (layer 3) fwd
|
||||
if rank == 3:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_3,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 3 id 0 (layer 3)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
######
|
||||
# fwd 4->1
|
||||
######
|
||||
|
||||
if rank == 3:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_3,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 3 id 1 (layer 4)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
if rank == 2:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_2,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 2 id 1 (layer 5)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
if rank == 1:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_1,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
print(
|
||||
f"chunk 1 id 1 (layer 6)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_f(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_0,
|
||||
model_chunk_id=chunk_id,
|
||||
input_obj=None,
|
||||
criterion=criterion,
|
||||
accum_loss=None,
|
||||
outputs=None,
|
||||
)
|
||||
# print(f"fwd output {output7}")
|
||||
print(
|
||||
f"chunk 0 id 1 (layer 7)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
##########################
|
||||
# Step2: bwd
|
||||
##########################
|
||||
######
|
||||
# bwd rank 4->1
|
||||
######
|
||||
# chunk 0 id 1 (layer 7) bwd
|
||||
if rank == 0:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_0,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
|
||||
# # chunk 1 id 1 (layer 6) bwd
|
||||
if rank == 1:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_1,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
|
||||
# chunk 2 id 1 (layer 5) bwd
|
||||
if rank == 2:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_2,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
|
||||
# chunk 3 id 1 (layer 4) bwd
|
||||
if rank == 3:
|
||||
chunk_id = 1
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_3,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
|
||||
# ######
|
||||
# # bwd rank 1->4
|
||||
# ######
|
||||
|
||||
# chunk 3 id 0 (layer 3) bwd
|
||||
if rank == 3:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_3,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
# print(f"input_grad3 {input_grad3}")
|
||||
|
||||
# chunk 2 id 0 (layer 2) bwd
|
||||
if rank == 2:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_2,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
# print(f"input_grad2 {input_grad2}")
|
||||
|
||||
# chunk 1 id 0 (layer 1) bwd
|
||||
if rank == 1:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_1,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
|
||||
# chunk 0 id 0 (layer 0) bwd
|
||||
if rank == 0:
|
||||
chunk_id = 0
|
||||
scheduler.schedule_b(
|
||||
scheduled_node=None,
|
||||
model_chunk=chunk_0,
|
||||
model_chunk_id=chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
)
|
||||
# print(f"input_grad0 {input_grad0}")
|
||||
|
||||
|
||||
# @pytest.mark.dist
|
||||
# @pytest.mark.parametrize("num_microbatch", [4])
|
||||
# @pytest.mark.parametrize("batch_size", [4])
|
||||
# @pytest.mark.parametrize("num_model_chunk", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp():
|
||||
spawn(
|
||||
test_zerobubble_pipeline_base,
|
||||
nprocs=4,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test_pp()
|
Loading…
Reference in New Issue