From 11ae6848c69e04c2a48487586a9ca1160749c8cd Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Sep 2024 17:33:09 +0800 Subject: [PATCH] [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/interface/optimizer.py | 21 +- colossalai/pipeline/__init__.py | 3 +- colossalai/pipeline/schedule/__init__.py | 2 + colossalai/pipeline/schedule/v_schedule.py | 494 ++++++++++ .../pipeline/schedule/zero_bubble_pp.py | 844 ++++++++++++++++++ .../test_schedule/test_zerobubble_pp.py | 769 ++++++++++++++++ 7 files changed, 2131 insertions(+), 4 deletions(-) create mode 100644 colossalai/pipeline/schedule/v_schedule.py create mode 100644 colossalai/pipeline/schedule/zero_bubble_pp.py create mode 100644 tests/test_pipeline/test_schedule/test_zerobubble_pp.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b4b40020f..1b3b765c2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1103,7 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", + enable_interleave=(pp_style == "interleaved"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index 6cd74b3b4..a236434a5 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -55,8 +55,25 @@ class OptimizerWrapper: """ loss.backward(*args, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - torch.autograd.backward(tensor, grad) + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): + """ + Performs a backward pass for dx or dw, + for dx, we only calculate dx = w*dy here + for dw, we only calculate dw = x*dy here + + Args: + tensor (Tensor): y or loss of current chunk; + grad_tensors (Tensor): dy of current chunk; + input_obj (Tensor): for dx, input_obj is x of current chunk; + for dw, input_obj is w of current chunk; + retain_graph (bool): default to be True, we retain graph in backward_b + """ + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def state_dict(self): """ diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 4754212c1..5d44530e7 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,11 +1,12 @@ from .p2p import PipelineP2PCommunication -from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule +from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler from .stage_manager import PipelineStageManager __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", "PipelineP2PCommunication", "PipelineStageManager", ] diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 6845dc237..05dd24e81 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -1,9 +1,11 @@ from .base import PipelineSchedule from .interleaved_pp import InterleavedSchedule from .one_f_one_b import OneForwardOneBackwardSchedule +from .zero_bubble_pp import ZeroBubbleVPipeScheduler __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", ] diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py new file mode 100644 index 000000000..9eebebdea --- /dev/null +++ b/colossalai/pipeline/schedule/v_schedule.py @@ -0,0 +1,494 @@ +# Refer from Zero Bubble Pipeline Parallelism. +# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism +# Paper: https://arxiv.org/abs/2401.10241 +# The following applies to all files unless otherwise noted: +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from collections import deque +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class ScheduledNode: + type: str + chunk: int + stage: int + minibatch: int + start_time: int = 0 + completion_time: int = 0 + rollback: bool = False + + +class PipelineGraph(object): + """PipelineGraph""" + + def __init__( + self, + n_stage, + n_micro, + f_cost, + b_cost, + w_cost, + c_cost, + f_mem, + b_mem, + w_mem, + max_mem=None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + def get_id(self, cat, chunk, stage, micro): + return ( + cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro + ) + + def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None): + count = [] + for i in range(self.n_stage): + count.append([0] * 6) + + end_time = [-1] * self.n_node + cur_time = [0] * self.n_stage + mem = [0] * self.n_stage + stage_bubble = [0] * self.n_stage + pending_w = [deque() for _ in range(self.n_stage)] + schedule = [[] for _ in range(self.n_stage)] + stage_str = [" " * i for i in range(self.n_stage)] + + if approved_bubble is None: + approved_bubble = [-1] * self.n_stage + max_approved_bubble = max(approved_bubble) + + def get_max_stage_bubble(stage=-1): + max_stage_bubble = 0 + for bb in stage_bubble: + max_stage_bubble = max(max_stage_bubble, bb) + if stage >= 0: + max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage]) + return max_stage_bubble + + def put_w(stage): + assert len(pending_w[stage]) > 0 + _, chunk_, _ = pending_w[stage].popleft() + put(2, chunk_, stage) + + def put(cat, chunk, stage, assert_cnt=True): + _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat] + _cnt = count[stage][cat * 2 + chunk] + # assert _cnt < self.n_micro + if _cnt >= self.n_micro: + if not assert_cnt: + stage_str[stage] += " " + cur_time[stage] = _tmp # TODO + return + assert False + assert mem[stage] + self.fbw_mem[cat] <= self.max_mem + stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1))) + if cat > 0 or chunk > 0: + last_id = cat * 2 + chunk - 1 + if cat < 2: + # if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0 + else: + assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0 + if chunk == 1 and cat < 2: + if stage < self.n_stage - 1: + _fa_id = self.get_id(cat, chunk, stage + 1, _cnt) + assert end_time[_fa_id] >= 0 + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + if chunk == 0 and cat < 2: + if stage > 0: + _fa_id = self.get_id(cat, chunk, stage - 1, _cnt) + # if end_time[_fa_id] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}" + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + _id = self.get_id(cat, chunk, stage, _cnt) + if count[stage][0] > 0: + stage_bubble[stage] += _tmp - _no_bubble + end_time[_id] = _tmp + cur_time[stage] = _tmp + mem[stage] += self.fbw_mem[cat] + # noinspection PyTypeChecker + schedule[stage].append((cat, chunk, _cnt)) + if cat == 1: + pending_w[stage].append((2, chunk, _cnt)) + count[stage][cat * 2 + chunk] += 1 + + # for _ in range(2 * self.n_stage): + # for i in range(self.n_stage): + # if count[i][1] >= count[i][0]: + # put(0, 0, i, assert_cnt=False) + # continue + # if i == self.n_stage - 1: + # put(0, 1, i, assert_cnt=False) + # continue + # fa_id = self.get_id(0, 1, i + 1, count[i][1]) + # if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO + # put(0, 1, i, assert_cnt=False) + # else: + # put(0, 0, i, assert_cnt=False) + + for i in range(self.n_stage): + put(0, 0, i) + for i in range(self.n_stage - 1, -1, -1): + if i == self.n_stage - 1: + put(0, 1, i) + continue + tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost + while ( + mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem + and cur_time[i] + self.fbw_cost[0] <= tmp + and count[i][0] < self.n_micro + ): + for j in range(i + 1): + put(0, 0, j) + put(0, 1, i) + iter_chunk_ = 0 + end_tmp = 0 + for i in range(self.n_stage): + if i == 0: + end_tmp = cur_time[0] + self.fbw_cost[1] + continue + tmp = end_tmp + self.c_cost + while ( + count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] + or count[i][1] <= count[i - 1][1] < self.n_micro + ): + for j in range(self.n_stage - 1, i - 1, -1): + if count[j][iter_chunk_] < self.n_micro: + put(0, iter_chunk_, j) + iter_chunk_ = 1 - iter_chunk_ + # while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp: + # if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]: + # break + # for j in range(self.n_stage - 1, i - 1, -1): + # if count[j][iter_chunk_] < self.n_micro: + # put(0, iter_chunk_, j) + # iter_chunk_ = 1 - iter_chunk_ + # end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1] + + # init_bubble = get_max_stage_bubble() + # print(stage_bubble) + for _ in range(2 * self.n_micro): + # check mem before putting b + for i in range(self.n_stage): + while mem[i] + self.fbw_mem[1] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + b0_ranks, b1_ranks = [], [] + for i in range(self.n_stage): + if count[i][3] >= count[i][2]: + b0_ranks.append(i) + elif i == self.n_stage - 1: + b1_ranks.append(i) + else: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro: + b1_ranks.append(i) + else: + b0_ranks.append(i) + b_ranks = [] + # put b1 + for i in reversed(b1_ranks): + b_ranks.append((i, 1)) + # put b0 + for i in b0_ranks: + b_ranks.append((i, 0)) + for i, _chunk_ in b_ranks: + fa_id = -1 + if _chunk_ == 1 and i < self.n_stage - 1: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if _chunk_ == 0 and i > 0: + fa_id = self.get_id(1, 0, i - 1, count[i][2]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if _chunk_ == 1: + put_w(i) + elif fill_b: + put_w(i) + put(1, _chunk_, i) + + # put f + for i in range(self.n_stage): + if count[i][1] >= self.n_micro: + continue + put_item = None + if count[i][1] >= count[i][0]: + put_item = 0 + elif i == self.n_stage - 1: + put_item = 1 + else: + if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0: + put_item = 1 + elif count[i][0] < self.n_micro: + if i == 0: + put_item = 0 + elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0: + put_item = 0 + if put_item is None: + continue + # check mem before putting f + while mem[i] + self.fbw_mem[0] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + fa_id = -1 + if put_item == 0 and i > 0: + fa_id = self.get_id(0, 0, i - 1, count[i][0]) + if put_item == 1 and i < self.n_stage - 1: + fa_id = self.get_id(0, 1, i + 1, count[i][1]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if fill_f: + put_w(i) + put(0, put_item, i) + + for i in range(self.n_stage): + while len(pending_w[i]) > 0: + put_w(i) + + # for i in range(self.n_stage): + # print(stage_str[i]) + + max_bubble = get_max_stage_bubble() + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + max_bubble / expected_time + # print("%6.4f" % bubble_rate, "->", stage_bubble) + if max_approved_bubble < 0 or max_bubble < max_approved_bubble: + _schedule, _end_time, _max_bubble = self.try_v_schedule( + fill_f=fill_f, + fill_b=fill_b, + approved_bubble=stage_bubble, + ) + if _max_bubble < max_bubble: + return _schedule, _end_time, _max_bubble + # print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble) + return schedule, end_time, max_bubble + + def print_details(self, end_time, print_scaling=1): + for stage in range(self.n_stage): + stage_str = ["."] * int(max(end_time) / print_scaling) + for _cat in range(3): + for _chunk in range(2): + for _micro in range(self.n_micro): + _id = self.get_id(_cat, _chunk, stage, _micro) + if end_time[_id] < 0: + continue + end = int(end_time[_id] / print_scaling) + start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling) + for j in range(start, end): + if j == start or j == end - 1: + stage_str[j] = "FfBbWw"[_cat * 2 + _chunk] + elif j == start + 1: + if _micro >= 10: + stage_str[j] = str(_micro // 10) + else: + stage_str[j] = str(_micro) + elif j == start + 2 and _micro >= 10: + stage_str[j] = str(_micro % 10) + else: + stage_str[j] = "-" + _str = "" + for _c in stage_str: + _str += _c + print(_str) + + def get_v_schedule(self, only_run_time=False): + schedule, end_time, max_bubble = None, None, None + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + for fill_b in [True, False]: + for fill_f in [True, False]: + _schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f) + # print("") + if max_bubble is None or _max_bubble < max_bubble: + max_bubble = _max_bubble + schedule = _schedule + end_time = _end_time + if only_run_time: + return max_bubble + expected_time + # self.print_details(end_time, print_scaling=1) + max_bubble / (expected_time + max_bubble) + # print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate)) + local_order = [[] for _ in range(self.n_stage)] + comm_id = {} + comm_id_counter = 0 + post_validation_time = 0 + for i in range(self.n_stage - 1, -1, -1): + pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1) + post_validation_time = max( + post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost + ) + # post_validation_time = 0 + # print(i, pv_id, post_validation_time) + for it in ["RECV_", "SEND_", ""]: + if i == 0 and it == "SEND_": + continue + if i == self.n_stage - 1 and it == "RECV_": + continue + # stage_ = i - 1 if it == "RECV_" else i + stage_ = i + local_order[stage_].append( + ScheduledNode( + type=it + "POST_VALIDATION", + chunk=0, + stage=stage_, + minibatch=0, + start_time=post_validation_time, + completion_time=post_validation_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + comm_id_counter += 1 + for i in range(self.n_stage): + for _cat_, _chunk_, _micro_ in schedule[i]: + complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)] + local_order[i].append( + ScheduledNode( + type="FBW"[_cat_], + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=i, + minibatch=_micro_, + start_time=complete_time - self.fbw_cost[_cat_], + completion_time=complete_time, + ) + ) + if _cat_ == 2: # no communication for W + continue + cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD" + + def communicate(send_recv, stage_): + # noinspection PyTypeChecker + local_order[stage_].append( + ScheduledNode( + type=send_recv + cat_str, + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=stage_, + minibatch=_micro_, + start_time=complete_time, + completion_time=complete_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + + if _chunk_ == 1 and i > 0: + communicate("SEND_", i) + communicate("RECV_", i - 1) + if _chunk_ == 0 and i < self.n_stage - 1: + communicate("SEND_", i) + communicate("RECV_", i + 1) + comm_id_counter += 1 + for rank in range(self.n_stage): + # For nodes with the same timestamp on the same stage, communication will be prioritized. + def even_breaker(x: ScheduledNode): + # Compute nodes are always delayed. + if x.type in ["F", "B", "W"]: + return comm_id_counter + # For comm nodes, order by their unique comm id + return comm_id[x] + + local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x)))) + # If a recv with intersects with previous computation, reorder them so that recv + # is executed before computation and hence can be overlapped. + for i in range(len(local_order[rank])): + if ( + i > 0 + and local_order[rank][i - 1].type in {"F", "B", "W"} + and local_order[rank][i].type.startswith("RECV") + and "POST_VALIDATION" not in local_order[rank][i].type + and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time + ): + local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i] + + local_order_with_rollback = [[] for _ in range(self.n_stage)] + for rank in range(self.n_stage): + rollback_comm = set() + if rank > 0: + for node in local_order[rank - 1]: + if node.type == "POST_VALIDATION": + break + if node.type == "SEND_FORWARD": + assert node.chunk == 0 + rollback_comm.add(node.minibatch) + for node in local_order[rank]: + if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm: + rollback = True + rollback_comm.remove(node.minibatch) + else: + rollback = False + local_order_with_rollback[rank].append( + ScheduledNode( + type=node.type, + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + rollback=rollback, + ) + ) + assert len(rollback_comm) == 0 + # for node in local_order_with_rollback[rank]: + # print(f"Rank {rank} Node info {node}") + # print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ") + # print() + + return local_order_with_rollback diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py new file mode 100644 index 000000000..c1c4f13c6 --- /dev/null +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -0,0 +1,844 @@ +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.cuda +from torch.nn import Module, ModuleList +from torch.utils._pytree import tree_map + +from colossalai.accelerator import get_accelerator +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.schedule.v_schedule import ScheduledNode +from colossalai.pipeline.stage_manager import PipelineStageManager + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device +from .base import PipelineSchedule + +AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} + + +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + +def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): + """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + """ + if (out is None) or (not deallocate_pipeline_outputs): + return + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." + # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + out.data.untyped_storage().resize_(0) + + +class ZeroBubbleVPipeScheduler(PipelineSchedule): + def __init__( + self, + stage_manager: PipelineStageManager, + schedule: List[ScheduledNode], + num_model_chunks: int, + num_microbatch: Optional[int] = None, + microbatch_size: Optional[int] = None, + enable_metadata_cache: bool = True, + overlap_p2p: bool = True, + ): + super().__init__(stage_manager) + # batch info + self.num_microbatch = num_microbatch + self.microbatch_size = microbatch_size + self.num_model_chunks = num_model_chunks + self.batch: Any + self.batch_size: int + self.last_batch_size: Optional[int] = None + self.microbatch_offset: List[int] + + self.schedules = schedule + # TODO: optim post valid + self.do_post_validation = False + + # P2PMeta cache + # self.enable_metadata_cache = enable_metadata_cache + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None + + # P2P communication + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + + # init communication map + self.communication_map = { + "SEND_FORWARD": self.send_forward, + "RECV_FORWARD": self.recv_forward, + "SEND_BACKWARD": self.send_backward, + "RECV_BACKWARD": self.recv_backward, + } + + # init buffer + self._free_buffers() + + def _free_buffers(self): + # free local buffer + # two dim array, first dim is the model chunk, second dim is the microbatch queue + + # x & y buffer for schedule b + self.input_tensors = [[], []] + self.output_tensors = [[], []] + + # y & dy buffer for schedule w + self.output_tensors_dw = [[], []] + self.output_tensors_grad_dw = [[], []] + + # buffer for communication + self.send_forward_buffer = [[], []] + self.recv_forward_buffer = [[], []] + self.send_backward_buffer = [[], []] + self.recv_backward_buffer = [[], []] + + # y buffer for local send fwd + self.local_send_forward_buffer = [] + # dy buffer for local send bwd + self.local_send_backward_buffer = [] + + def assert_buffer_empty(self): + # assert buuffer is empty at end + assert len(self.input_tensors[0]) == 0 + assert len(self.input_tensors[1]) == 0 + assert len(self.output_tensors[0]) == 0 + assert len(self.output_tensors[1]) == 0 + assert len(self.output_tensors_dw[0]) == 0 + assert len(self.output_tensors_dw[1]) == 0 + assert len(self.output_tensors_grad_dw[0]) == 0 + assert len(self.output_tensors_grad_dw[1]) == 0 + assert len(self.send_forward_buffer[0]) == 0 + assert len(self.send_forward_buffer[1]) == 0 + assert len(self.recv_forward_buffer[0]) == 0 + assert len(self.recv_forward_buffer[1]) == 0 + assert len(self.send_backward_buffer[0]) == 0 + assert len(self.send_backward_buffer[1]) == 0 + assert len(self.recv_backward_buffer[0]) == 0 + assert len(self.recv_backward_buffer[1]) == 0 + assert len(self.local_send_forward_buffer) == 0 + assert len(self.local_send_backward_buffer) == 0 + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + self.batch = batch + self.batch_size = get_batch_size(batch) + + if self.microbatch_size is None: + assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + if self.num_microbatch is None: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatch = self.batch_size // self.microbatch_size + + if not self.forward_only: + assert self.last_batch_size is None or self.last_batch_size == self.batch_size + assert self.batch_size == self.microbatch_size * self.num_microbatch + + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None + + self.last_batch_size = self.batch_size + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" + microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not is_forward: + # Reverse order + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 & is_first_stage + # do nothing; cause u are chunk 0 in first rank, u have no prev rank; + ################# + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 0 & not is_first_stage + # Recv y from PREV_rank as input + ################# + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) + self.recv_forward_buffer[model_chunk_id].append(input_tensor) + return input_tensor, wait_handles + + else: + ################ + # chunk = 1 & is_last_stage + # do nothing; cause u get y from local_send_forward_buffer in schedule f + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 1 & not is_last_stage + # recv y from NEXT_rank as input + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor, wait_handles = self.comm.recv_forward(next_rank) + self.recv_forward_buffer[model_chunk_id].append(input_tensor) + return input_tensor, wait_handles + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 & is_last_stage + # do nothing; Already get dy from local_send_backward_buffer in schedule b + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 0 & not is_last_stage + # Recv bwd from next stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) + self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + return output_tensor_grad, wait_handles + + else: + # bwd chunk1 is left V; + ################ + # chunk = 1 & is_first_stage + # do nothing; get loss from local + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 1 & not first stage + # recv_backward recv bwd from prev stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) + self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + return output_tensor_grad, wait_handles + + def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: + """Sends the input tensor to the next stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 && is_last_stage + # do nothing; hold y on local_send_forward_buffer + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 0 && not is_last_stage + # self.comm.send_forward send y to NEXT stage + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) + return send_handles + + else: + ################ + # chunk = 1 && is_first_stage + # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 1 && not is_first_stage + # self.comm.send_forward send y to PREV stage + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_forward(output_tensor, prev_rank) + return send_handles + + def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: + """Sends the gradient tensor to the previous stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + Any: The wait handles for the communication. + """ + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 && is_first_stage + # do nothing; cause u are the first chunk in first stage; bwd end + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 0 && not is_first_stage + # Send dx to PREV stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) + return send_handles + + # bwd chunk1 is left V; + else: + ################ + # chunk = 1 && is_last_stage + # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return [] + + ################ + # chunk = 1 && not is_last_stage + # Send dx to NEXT stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) + send_handles = self.comm.send_backward(input_tensor_grad, next_rank) + return send_handles + + def forward_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + input_obj (Optional[dict]): x; + criterion (Callable): loss function; + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + # Load input ids, attention mask and labels + # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. + # Only attention_mask from micro_batch is used + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + # fwd calculate + output_obj = model_chunk[model_chunk_id](input_obj) + # last layer in model + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + loss = criterion(output_obj) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_b_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward dx step of the pipeline; we calculate "dx = w*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): x. + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Optional[dict]: dx. + """ + # calculate bwd b step ; only dx = w*dy; + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss; so output_obj_grad should be None + assert output_obj_grad is None + + optimizer.backward_by_grad( + tensor=output_obj, + grad=output_obj_grad, + inputs=input_obj, + retain_graph=True, + ) + return input_obj.grad + + def backward_w_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ): + """Backward dw step of the pipeline; we calculate "dw = x*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Nothing need to return; we only calculate dw then update w; + """ + # calculate bwd w step ; only dw = x*dy; + + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss + output_obj_grad = None + optimizer.backward_by_grad( + tensor=output_obj, + grad=output_obj_grad, + inputs=list(model_chunk[model_chunk_id].parameters()), + retain_graph=False, + ) + + def schedule_f( + self, + scheduled_node, + model_chunk: torch.nn.ModuleList, + model_chunk_id: int, + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ): + """A complete forward schedule; Include recv fwd --> cal fwd --> send fwd; + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + criterion (Callable): loss function; + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Nothing. + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + # Step1: recv fwd + if model_chunk_id == 0: + # is first stage; get input from func param + if self.stage_manager.is_first_stage(ignore_chunk=True): + input_obj = micro_batch + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + else: + # is last stage; recv from local + if self.stage_manager.is_last_stage(ignore_chunk=True): + input_obj = self.local_send_forward_buffer.pop(0) + # not last stage; recv from next + else: + input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + + # Here, let input_obj.requires_grad_() + tree_map(torch.Tensor.requires_grad_, input_obj) + + # Step2: fwd step + output_obj = self.forward_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + input_obj=input_obj, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not detach bwd LOSS + pass + else: + detached_output_obj = output_obj.clone().detach() + + # Step3: send fwd + # add output to send_fwd_buffer + if model_chunk_id == 0: + # is last stage; send to local_send_forward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_forward_buffer.append(detached_output_obj) + else: + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + else: + # is first stage; end of fwd; append LOSS to local_send_backward_buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + else: + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + + # add input and output object for backward b + self.input_tensors[model_chunk_id].append(input_obj) + # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj + deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) + self.output_tensors[model_chunk_id].append(output_obj) + # add output object for backward w + self.output_tensors_dw[model_chunk_id].append(output_obj) + + def schedule_b( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + # input_obj: Optional[dict], + # output_obj: Union[dict, torch.Tensor], + # output_obj_grad: Optional[dict], + ): + """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + Returns: + Nothing. + """ + + # Step1: recv bwd + if model_chunk_id == 0: + # chunk0 is last stage; recv output_grad from local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # chunk 0 not last stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + else: + # chunk1, is first stage; recv LOSS from local send bwd buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + output_tensor_grad = None + # chunk1, not first stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + + # get input and output object from buffer; + input_obj = self.input_tensors[model_chunk_id].pop(0) + output_obj = self.output_tensors[model_chunk_id].pop(0) + + # save output_tensor_grad for dw + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # we save loss here + self.output_tensors_grad_dw[model_chunk_id].append(output_obj) + else: + # we save output_tensor_grad here + self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) + + # _wait_p2p(recv_bwd_handles) + # Step2: bwd step + input_object_grad = self.backward_b_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + optimizer=optimizer, + input_obj=input_obj, + output_obj=output_obj, + output_obj_grad=output_tensor_grad, + ) + + # Step3: send bwd + if model_chunk_id == 0: + # do nothing; end of bwd; + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + # save input_object_grad to send_backward_buffer + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + else: + # send to local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(input_object_grad) + # send to next + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + + def schedule_w( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + ): + """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w); + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + Returns: + Nothing. + """ + + # get y & dy from buffer + output_obj = self.output_tensors_dw[model_chunk_id].pop(0) + output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) + + self.backward_w_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + optimizer=optimizer, + output_obj=output_obj, + output_obj_grad=output_obj_grad, + ) + + def run_forward_only( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False, + ) -> Dict: + assert self.forward_only + + # prepare batch + self.load_batch(data_iter) + + # prepare accum loss & output + accum_loss = None + + # reset accum loss at fwd end; + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) + + outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None + + # while we still have schedules_node in self.schedules + for it in range(len(self.schedules)): + scheduled_node = self.schedules[it] + + if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: + # communication + communication_func = self.communication_map[scheduled_node.type] + communication_func(scheduled_node.chunk) + if scheduled_node.type == "F": + self.schedule_f( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + # return loss & output + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} + + def run_forward_backward( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> Dict: + """ + Runs Zerobubble schedule, with communication between pipeline stages. + """ + # prepare batch + self.load_batch(data_iter) + + # prepare accum loss & output + accum_loss = None + + # reset accum loss at fwd end; + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) + + outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None + + # while we still have schedules_node in self.schedules + schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + for it in range(len(schedule)): + scheduled_node = schedule[it] + if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: + # communication + communication_func = self.communication_map[scheduled_node.type] + communication_func(scheduled_node.chunk) + + if scheduled_node.type == "F": + self.schedule_f( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + criterion=criterion, + accum_loss=accum_loss, + outputs=outputs, + ) + elif scheduled_node.type == "B": + self.schedule_b( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, + ) + elif scheduled_node.type == "W": + self.schedule_w( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, + ) + + # return loss & output + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} + + def forward_backward_step( + self, + model_chunk: Union[ModuleList, Module], + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: + """ + Args: + model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + self.forward_only = not torch.is_grad_enabled() + if optimizer is None: + assert self.forward_only, "Optimizer should be passed when doing backward." + + if self.forward_only: + result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) + else: + result = self.run_forward_backward( + model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + self.assert_buffer_empty() + + return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py new file mode 100644 index 000000000..825c192d8 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -0,0 +1,769 @@ +from copy import deepcopy +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +class MlpModel(nn.Module): + def __init__(self, in_dim, out_dim, num_layers): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +# 1) Test manual v_schedule with multiple microbatch +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_iter_input(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + + # schedule list + zbv_schedule = [ + # stage 0 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 1 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 2 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), + ], + # stage 3 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), + ], + ] + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + stage_manager=stage_manager, + num_model_chunks=pp_size, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = 4 + num_layers = 8 + in_dim = out_dim = 8 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + +# 2) add optimizer base 1) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 8, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_vschedule_with_optim(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + + h, a, s = 4096, 32, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatch, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + stage_manager=stage_manager, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # init loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = test_config["batch_size"] + num_layers = 8 + assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" + in_dim = out_dim = 4096 + before_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) + + after_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 + + # assert memory + if rank != 0: + # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + # output hid_dim * hid_dim * 4(fp32) / 1024**3 + # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}") + assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3) + else: + # rank0 will also hold output; + print( + f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + ) + assert round((after_pp_step_memory - after_init_memory), 5) <= round( + (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 + ) + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + + ########################## + # assert loss & output + ########################## + # only chunk 1 stage 0 hold loss and output + if rank == 0: + assert_close(result["loss"], loss_base) + assert_close(result["outputs"], output_base) + + # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + ########################## + # assert optim state + ########################## + optim_base_state = optimizer_base.state_dict()["state"] + optim_pp_state = optimizer_pp.state_dict()["state"] + optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] + optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] + # if rank == 0: + # print(f"optim_base_state {optim_base_state}") + + # assert param group + for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): + if key_base == key_pp: + if key_base != "params": + assert val_base == val_pp + else: + # BUG: + # param_base: [0, 1, 2, 3, 4, 5, 6, 7]; + # params pp: [0, 1]; + assert val_base[:2] == val_pp + + # assert state + assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"]) + assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"]) + + +# TODO:4) support Hybrid base 3) +def run_with_hybridplugin(test_config): + pass + + +# TODO:5) support MoEHybrid base 3) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_with_moehybridplugin(test_config): + model_zoo.get_sub_registry("transformers_bert") + test_config["use_lazy_init"] = False + test_config["initial_scale"] = 2**16 + model_list = [ + "transformers_bert", + ] + + +# TODO:6) support booster & Hybrid base 4) + +# TODO:7) support booster & MoEHybrid base 4) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # run_fwd_bwd_iter_input() + run_fwd_bwd_vschedule_with_optim() + # run_with_moehybridplugin() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn( + run_dist, + nprocs=4, + ) + + +if __name__ == "__main__": + test_pp()