ColossalAI/tests/test_pipeline/test_schedule/test_zerobubble_pp.py

903 lines
39 KiB
Python
Raw Normal View History

from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Tuple
2024-08-27 07:31:58 +00:00
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
2024-08-23 06:04:12 +00:00
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
2024-08-29 03:16:59 +00:00
from colossalai.interface import OptimizerWrapper
2024-09-02 09:50:47 +00:00
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
2024-09-12 02:51:46 +00:00
from colossalai.tensor.d_tensor.api import clear_layout_converter
2024-09-02 09:50:47 +00:00
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
[zerobubble] Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert errors remain; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; change schedule loop from "while" to "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requires_grad for output; * [fix] fix requires_grad position and detach position and input&output local buffer append position; * [feat] add memory assertion; * [fix] fix mem check; * [fix] mem assertion * [fix] fix mem assertion * [fix] fix mem; use a new model shape; only assert mem is less than or equal to the theoretical value; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertion in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.requires_grad_ with tree_map; * [fix] update optim state dict assert (include param
group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4;
2024-09-10 09:33:09 +00:00
from tests.kit.model_zoo import model_zoo
class MlpModel(nn.Module):
    """Stack of bias-free ``nn.Linear`` layers applied one after another.

    Args:
        in_dim: feature size consumed by the first layer.
        out_dim: feature size produced by every layer; layers after the first
            also consume ``out_dim`` features so the stack always chains.
        num_layers: number of linear layers in ``self.layers``.
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super().__init__()
        # ``bias`` is a bool flag on nn.Linear, so pass False rather than None.
        # Only the first layer sees ``in_dim``; when in_dim == out_dim every layer
        # has the same shape as before (and consumes the RNG identically).
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim if i == 0 else out_dim, out_dim, bias=False) for i in range(num_layers)]
        )

    def forward(
        self,
        hidden_states,
    ):
        """Feed ``hidden_states`` through every layer in order and return the result."""
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    hidden_states: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
    model_chunk_id: int = None,
):
    """Run a layer's original forward with ``stage_mgr`` switched to ``model_chunk_id``.

    ``forward`` is the layer's saved original forward (the tests bind it in via
    ``MethodType(partial(pp_linear_fwd, ...), sub_model._forward)``). ``data`` is
    accepted but never used. ``hidden_states`` is the tensor passed to ``forward``.

    Returns the bare output tensor when ``stage_mgr.is_first_stage()`` is true and
    ``model_chunk_id == 1`` (the "fwd end" case). Otherwise returns
    ``{"hidden_states": output}``.

    NOTE(review): run_fwd_bwd_vschedule_with_optim's criterion indexes
    ``x["hidden_states"]`` on the final output. Confirm whether ``is_first_stage()``
    can actually be true while chunk 1 is active.
    """
    with stage_mgr.switch_model_chunk_id(model_chunk_id):
        at_v_end = stage_mgr.is_first_stage() and model_chunk_id == 1
        out = forward(hidden_states)
        if at_v_end:
            return out
        # fwd start (chunk 0 on the first stage) and fwd middle share the dict form
        return {"hidden_states": out}
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    """Count the parameter elements of ``model``.

    Returns:
        ``(total, trainable)``, where ``trainable`` only counts parameters with
        ``requires_grad`` set.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return total, trainable
def _build_manual_zbv_schedule(pp_size, num_microbatch):
    """Build the hand-written, strictly sequential per-stage node lists.

    Every stage repeats the same pattern for each microbatch:
      chunk 0 fwd, chunk 1 fwd, chunk 1 bwd (B then W), chunk 0 bwd (B then W).
    Each computation node is wrapped by its RECV_*/SEND_* communication nodes.
    Generating the lists guarantees that every node carries its own ``stage`` index.

    Returns:
        A list with one ``List[ScheduledNode]`` per stage.
    """
    schedule = []
    for stage in range(pp_size):
        nodes = []
        for mb in range(num_microbatch):
            for chunk in (0, 1):
                nodes.extend(
                    ScheduledNode(type=node_type, chunk=chunk, stage=stage, minibatch=mb)
                    for node_type in ("RECV_FORWARD", "F", "SEND_FORWARD")
                )
            for chunk in (1, 0):
                nodes.extend(
                    ScheduledNode(type=node_type, chunk=chunk, stage=stage, minibatch=mb)
                    for node_type in ("RECV_BACKWARD", "B", "W", "SEND_BACKWARD")
                )
        schedule.append(nodes)
    return schedule


# 1) Test manual v_schedule with multiple microbatch
@parameterize(
    "test_config",
    [
        {
            "batch_size": 8,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 2,
        },
    ],
)
def run_fwd_bwd_iter_input(test_config):
    """Compare a hand-scheduled ZBV pipeline step against a single-device run.

    Rank ``r`` holds layer ``r`` as chunk 0 and layer ``num_layers - 1 - r`` as
    chunk 1 (the V placement). After one pipelined forward/backward plus SGD step,
    the local weights and grads must match those of a full deep copy of the model
    trained the same way.
    """
    # init dist
    rank = dist.get_rank()
    pp_size = test_config["pp_size"]
    pg_mesh = ProcessGroupMesh(pp_size)
    num_microbatch = test_config["num_microbatches"]
    num_model_chunk = test_config["num_model_chunk"]

    # stage_manager
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )

    # schedule list: the same sequential pattern on every stage
    zbv_schedule = _build_manual_zbv_schedule(pp_size, num_microbatch)

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        # chunks per stage (2 for the V), not the number of pipeline stages
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
        overlap_p2p=False,
    )

    # loss func
    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # init model and input
    batch_size = 4
    num_layers = 8
    in_dim = out_dim = 8
    # The V placement below gives every (stage, chunk) pair exactly one layer.
    assert num_model_chunk == 2 and num_layers == 2 * pp_size, (
        f"V placement needs 2 chunks and {2 * pp_size} layers, "
        f"got {num_model_chunk} chunks and {num_layers} layers"
    )
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
    input_base = [t.clone() for t in data_iter]
    model_base = deepcopy(model)

    # V-shaped placement: rank0 -> layers 0 & 7, rank1 -> 1 & 6, rank2 -> 2 & 5, rank3 -> 3 & 4
    local_layer_ids = (rank, num_layers - 1 - rank)
    local_chunk = torch.nn.ModuleList().to(rank)
    for layer_idx in local_layer_ids:
        local_chunk.append(model.layers[layer_idx])

    # init optimizer
    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )
    torch.cuda.synchronize()
    scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=optimizer_pp,
        return_loss=True,
        return_outputs=True,
    )
    optimizer_pp.step()

    ##########################
    # Fwd bwd for base
    ##########################
    output_base = model_base(input_base[0])
    loss_base = criterion(output_base)
    loss_base.backward()
    optimizer_base.step()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # assert weight
    ##########################
    for local_idx, layer_idx in enumerate(local_layer_ids):
        assert_close(local_chunk[local_idx].weight, model_base.layers[layer_idx].weight)
        assert_close(local_chunk[local_idx].weight.grad, model_base.layers[layer_idx].weight.grad)
2024-09-02 09:50:47 +00:00
# 2) add optimizer base 1)
@parameterize(
    "test_config",
    [
        {
            "batch_size": 8,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 2,
        },
        {
            "batch_size": 8,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 8,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 2,
        },
    ],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
    """Run one ZBV step scheduled by ``PipelineGraph.get_v_schedule`` and compare with a base model.

    Checks, in order:
      * memory retained by the pipelined step stays under a fixed bound;
      * rank 0 reproduces the base loss and output;
      * every rank's weights and grads match the corresponding base layers;
      * SGD param-group settings and momentum buffers match the base optimizer.
    """
    # init dist
    rank = dist.get_rank()
    pp_size = test_config["pp_size"]
    pg_mesh = ProcessGroupMesh(pp_size)
    num_microbatch = test_config["num_microbatches"]
    num_model_chunk = test_config["num_model_chunk"]
    # stage_manager
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )

    # cost / memory model handed to the V-schedule generator
    h, a, s = 4096, 32, 1024
    mem_f = 34 * h + 5 * a * s
    mem_w = -32 * h
    mem_b = -mem_w - mem_f
    graph = PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatch,
        f_cost=1,
        b_cost=1,
        w_cost=1,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
    )
    zbv_schedule = graph.get_v_schedule()

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
        overlap_p2p=False,
    )

    # init loss func; the pipeline criterion receives the {"hidden_states": ...} dict
    def criterion(x, *args, **kwargs):
        x = x["hidden_states"]
        return (x * x).mean()

    def criterion_base(x, *args, **kwargs):
        return (x * x).mean()

    # init model and input
    batch_size = test_config["batch_size"]
    num_layers = 8
    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
    # The V placement below gives every (stage, chunk) pair exactly one layer.
    assert num_model_chunk == 2 and num_layers == 2 * pp_size, (
        f"V placement needs 2 chunks and {2 * pp_size} layers, "
        f"got {num_model_chunk} chunks and {num_layers} layers"
    )
    in_dim = out_dim = 4096
    before_init_memory = torch.cuda.memory_allocated() / 1024**3
    print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
    input_base = {k: v.clone() for k, v in data_iter.items()}
    model_base = deepcopy(model)

    # V-shaped placement: rank0 -> layers 0 & 7, rank1 -> 1 & 6, rank2 -> 2 & 5, rank3 -> 3 & 4.
    # local_chunk[i] is model chunk i on this rank.
    local_layer_ids = (rank, num_layers - 1 - rank)
    local_chunk = torch.nn.ModuleList().to(rank)
    for chunk_id, layer_idx in enumerate(local_layer_ids):
        sub_model = model.layers[layer_idx]
        # Route the layer through pp_linear_fwd so it speaks the scheduler's dict protocol.
        sub_model._forward = sub_model.forward
        sub_model.forward = MethodType(
            partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=chunk_id),
            sub_model._forward,
        )
        local_chunk.append(sub_model)

    # init optimizer
    optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))

    after_init_memory = torch.cuda.memory_allocated() / 1024**3
    print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")

    torch.cuda.synchronize()
    result = scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter([data_iter]),
        criterion=criterion,
        optimizer=optimizer_pp,
        return_loss=True,
        return_outputs=True,
    )
    optimizer_pp.step()

    after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3

    # assert memory
    # w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3
    # output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3
    # optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
    mem_bound = in_dim * in_dim * 4 * 5 * batch_size / 1024**3
    mem_used = after_pp_step_memory - after_init_memory
    if rank != 0:
        print(f" num_microbatch {num_microbatch} rank {rank}: {mem_used} <= {mem_bound}")
        assert mem_used <= mem_bound
    else:
        # rank0 will also hold output;
        mem_bound_rank0 = round(mem_bound + batch_size * in_dim * in_dim * 4 / 1024**3, 5)
        print(f" num_microbatch {num_microbatch} rank {rank}: {round(mem_used, 5)} <= {mem_bound_rank0}")
        assert round(mem_used, 5) <= mem_bound_rank0

    ##########################
    # Fwd bwd for base
    ##########################
    output_base = model_base(input_base["hidden_states"])
    loss_base = criterion_base(output_base)
    loss_base.backward()
    optimizer_base.step()

    ##########################
    # assert loss & output
    ##########################
    # only chunk 1 stage 0 hold loss and output
    if rank == 0:
        assert_close(result["loss"], loss_base)
        assert_close(result["outputs"]["hidden_states"], output_base)

    ##########################
    # assert weight
    ##########################
    for local_idx, layer_idx in enumerate(local_layer_ids):
        assert_close(local_chunk[local_idx].weight, model_base.layers[layer_idx].weight)
        assert_close(local_chunk[local_idx].weight.grad, model_base.layers[layer_idx].weight.grad)

    ##########################
    # assert optim state
    ##########################
    optim_base_state = optimizer_base.state_dict()["state"]
    optim_pp_state = optimizer_pp.state_dict()["state"]
    optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
    optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
    # assert param group
    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
        if key_base == key_pp:
            if key_base != "params":
                assert val_base == val_pp
            else:
                # state_dict() stores positional param indices per optimizer:
                # base holds [0, 1, ..., 7], pp holds [0, 1]; only the prefix is comparable.
                assert val_base[:2] == val_pp
    # assert state: pp index i belongs to local_chunk[i], i.e. base layer local_layer_ids[i].
    # Each bias-free layer owns exactly one parameter, so the base state index equals the layer index.
    for local_idx, layer_idx in enumerate(local_layer_ids):
        assert_close(optim_pp_state[local_idx]["momentum_buffer"], optim_base_state[layer_idx]["momentum_buffer"])
2024-09-02 09:50:47 +00:00
# TODO:4) support Hybrid base 3)
def run_with_hybridplugin(test_config):
    """Placeholder for the HybridParallelPlugin + ZBV test.

    Accepts ``test_config`` for signature compatibility. Currently does nothing
    and returns ``None``.
    """
    return None


# TODO:5) support MoEHybrid base 3)
@parameterize(
    "test_config",
    [
        {
            "pp_style": "zbv",
            "tp_size": 1,
            "ep_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunks": 2,
        },
    ],
)
def run_with_moehybridplugin(test_config):
    """Run one baseline (non-pipelined) training step on the ``transformers_bert`` model-zoo entry.

    Work in progress: the comparison against a MoeHybridParallelPlugin run
    with the zero-bubble V schedule is still commented out below.

    Args:
        test_config: plugin configuration injected by ``parameterize``. It is
            copied before any keys are added, so the shared decorator literal
            is never mutated.
    """
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    # Work on a private copy: the dict literal in the decorator is the same
    # object on every call, so adding keys to it directly would leak across calls.
    test_config = dict(test_config)
    # test_config["use_lazy_init"] = False
    test_config["initial_scale"] = 2**16
    model_list = ["transformers_bert"]
    clear_layout_converter()

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        if name not in model_list:
            continue

        # base param
        model = model_fn()
        data = data_gen_fn()
        criterion = loss_fn
        optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
        # The zoo's loss_fn is written against the *transformed* output, so run
        # output_transform_fn first (mirrors how the shardformer test utils
        # build their criterion). NOTE(review): confirm for non-identity transforms.
        output = output_transform_fn(model(**data))
        loss = criterion(output)
        loss.backward()
        optimizer.step()

        # NOTE(review): when enabling the block below, take the deepcopy of
        # model/data *before* the baseline backward/step above. Otherwise the
        # pipelined model starts from already-updated weights. execute_pipeline
        # is also handed the baseline model/optimizer; it presumably should use
        # model_pp/optimizer_pp.
        # # pp param
        # model_pp = deepcopy(model)
        # data_pp = deepcopy(data)
        # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))

        # # init pipeline graph
        # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
        # mem_f = 34 * h + 5 * a * s
        # mem_w = -32 * h
        # mem_b = -mem_w - mem_f
        # graph = PipelineGraph(
        #     n_stage=test_config["pp_size"],
        #     n_micro=test_config["num_microbatches"],
        #     f_cost=1,
        #     b_cost=1,
        #     w_cost=1,
        #     c_cost=1,
        #     f_mem=mem_f,
        #     b_mem=mem_b,
        #     w_mem=mem_w,
        #     # max_mem=mem_f * (p * 2 + m_offset),
        # )

        # zbv_schedule = graph.get_v_schedule()

        # test_config["scheduler_nodes"] = zbv_schedule
        # plugin = MoeHybridParallelPlugin(
        #     **test_config
        # )
        # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
        #     model = model_pp,
        #     optimizer = optimizer_pp,
        #     criterion = criterion,
        #     dataloader = data_pp,
        # )

        # output_pp = plugin.execute_pipeline(
        #     data_iter=iter(data),
        #     model=model,
        #     criterion=criterion,
        #     optimizer=optimizer,
        #     return_loss = True,
        #     return_outputs = True,
        # )


# TODO:6) support booster & Hybrid base 4)
# TODO:7) support booster & MoEHybrid base 4)
@parameterize(
    "test_config",
    [
        dict(
            pp_style="zbv",
            tp_size=1,
            ep_size=1,
            pp_size=4,
            num_microbatches=4,
            zero_stage=1,
            precision="bf16",
            num_model_chunks=2,
        ),
    ],
)
def run_with_booster_moehybridplugin(test_config):
    """Placeholder for the booster + MoE hybrid plugin variant of the ZBV test.

    Not implemented yet: each parameterized ``test_config`` is accepted and ignored.
    """


def run_dist(rank, world_size, port):
    """Per-process worker started by ``spawn``.

    Silences pre-existing loggers, joins the NCCL process group on
    ``localhost:port`` as ``rank`` of ``world_size``, then runs the enabled checks.
    """
    disable_existing_loggers()
    colossalai.launch(
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )

    # Currently disabled checks, kept here for manual runs:
    #   run_fwd_bwd_iter_input()
    #   run_with_moehybridplugin()
    #   run_with_booster_moehybridplugin()
    run_fwd_bwd_vschedule_with_optim()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    """Launch ``run_dist`` on 4 processes (ranks 0-3)."""
    world_size = 4
    spawn(run_dist, nprocs=world_size)


if __name__ == "__main__":
    test_pp()