# ColossalAI/tests/test_pipeline/test_schedule/test_zerobubble_pp.py

from copy import deepcopy
from typing import Tuple
2024-08-27 07:31:58 +00:00
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
2024-08-23 06:04:12 +00:00
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
2024-08-29 03:16:59 +00:00
from colossalai.interface import OptimizerWrapper
2024-09-02 09:50:47 +00:00
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
2024-09-02 10:00:43 +00:00
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
2024-09-02 09:50:47 +00:00
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
2024-09-02 10:00:43 +00:00
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
class MlpModel(nn.Module):
    """Stack of bias-free linear layers used as a toy model for the pipeline tests.

    Args:
        in_dim: Input feature size of every layer.
        out_dim: Output feature size of every layer. The layers are chained,
            so with more than one layer ``forward`` requires ``in_dim == out_dim``.
        num_layers: Number of ``nn.Linear`` layers to create.
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super().__init__()
        # ``bias`` is a boolean flag: pass ``False`` explicitly instead of
        # relying on ``None`` being falsy.
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=False) for _ in range(num_layers)])

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the last output."""
        for layer in self.layers:
            x = layer(x)
        return x
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    """Count the parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        ``(total, trainable)``: the number of elements over all parameters,
        and over the parameters that have ``requires_grad`` set.
    """
    # Walk the parameters once and remember (element count, trainable flag).
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return total, trainable
# 1) Test a hand-written V schedule with multiple microbatches
def _manual_v_schedule(num_stages, num_microbatch):
    """Build the fully serialized, hand-written schedule for ``run_fwd_bwd_iter_input``.

    On every stage each microbatch runs, in order: chunk 0 forward, chunk 1
    forward, chunk 1 backward, chunk 0 backward. A forward phase is
    ``RECV_FORWARD, F, SEND_FORWARD``; a backward phase is
    ``RECV_BACKWARD, B, W, SEND_BACKWARD``.

    Args:
        num_stages: Number of pipeline stages to generate a node list for.
        num_microbatch: Number of microbatches scheduled on every stage.

    Returns:
        A list with one entry per stage; each entry is that stage's ordered
        list of ``ScheduledNode``.
    """
    forward_ops = ("RECV_FORWARD", "F", "SEND_FORWARD")
    backward_ops = ("RECV_BACKWARD", "B", "W", "SEND_BACKWARD")
    # (chunk, ops) pairs in execution order for a single microbatch.
    phases = ((0, forward_ops), (1, forward_ops), (1, backward_ops), (0, backward_ops))

    schedule = []
    for stage in range(num_stages):
        stage_nodes = []
        for microbatch in range(num_microbatch):
            for chunk, ops in phases:
                # Every node is tagged with the stage that owns it; generating
                # the table avoids copy-paste slips such as stage-1 nodes
                # being labeled ``stage=0``.
                stage_nodes.extend(
                    ScheduledNode(type=op, chunk=chunk, stage=stage, minibatch=microbatch) for op in ops
                )
        schedule.append(stage_nodes)
    return schedule


@parameterize(
    "test_config",
    [
        {
            "batch_size": 4,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 4,
        },
    ],
)
def run_fwd_bwd_iter_input(test_config):
    """Run one scheduler step with a manual schedule and compare against a plain model.

    The 8-layer ``MlpModel`` is split so that each rank owns two layers. After
    one ``forward_backward_step`` and one optimizer step, the local weights
    and gradients are compared with a deep copy of the model that was trained
    without pipelining.

    Must be called in a process where the distributed backend is already
    initialized, because it reads ``dist.get_rank()``.

    Args:
        test_config: Supplied by ``@parameterize``. Keys read here:
            ``pp_size``, ``num_microbatches``, ``num_model_chunk``, ``batch_size``.
    """
    # init dist
    rank = dist.get_rank()
    pp_size = test_config["pp_size"]
    pg_mesh = ProcessGroupMesh(pp_size)
    num_microbatch = test_config["num_microbatches"]
    num_model_chunk = test_config["num_model_chunk"]

    # stage_manager
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )

    # schedule list: one node list per stage
    zbv_schedule = _manual_v_schedule(num_stages=pp_size, num_microbatch=num_microbatch)

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
        overlap_p2p=False,
    )

    # loss func
    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # init model and input
    batch_size = test_config["batch_size"]
    num_layers = 8
    in_dim = out_dim = 8
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]

    input_base = [t.clone() for t in data_iter]
    model_base = deepcopy(model)

    # rank -> indices of the two base-model layers it owns (layer r and its
    # mirror 7 - r); any other rank falls back to the last pair.
    layer_pairs = {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
    owned_layers = layer_pairs.get(rank, layer_pairs[3])
    # NOTE(review): ranks 0-2 collect their layers in an nn.ModuleList while
    # the remaining rank uses nn.Sequential; kept as-is, confirm whether the
    # difference is intentional.
    container = torch.nn.ModuleList() if rank in (0, 1, 2) else torch.nn.Sequential()
    local_chunk = container.to(rank)
    for layer_idx in owned_layers:
        local_chunk.append(model.layers[layer_idx])

    # init optimizer
    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )

    torch.cuda.synchronize()
    # The returned loss/outputs are not inspected in this test.
    scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=optimizer_pp,
        return_loss=True,
        return_outputs=True,
    )

    optimizer_pp.step()

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base[0])
    loss_base = criterion(output_base)
    loss_base.backward()
    optimizer_base.step()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # assert weight
    ##########################
    if rank in layer_pairs:
        # local_chunk[i] must match the i-th base layer owned by this rank.
        for local_idx, base_idx in enumerate(layer_pairs[rank]):
            assert_close(local_chunk[local_idx].weight, model_base.layers[base_idx].weight)
            assert_close(local_chunk[local_idx].weight.grad, model_base.layers[base_idx].weight.grad)
# 2) Same flow as 1), with a PipelineGraph-generated schedule plus loss/output/optimizer checks
@parameterize(
    "test_config",
    [
        {
            "batch_size": 4,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 4,
        },
    ],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
    """Run one scheduler step with a generated V schedule and compare against a plain model.

    The schedule is produced by ``PipelineGraph.get_v_schedule()``. After one
    ``forward_backward_step`` and one optimizer step, the loss/outputs (rank 0
    only), the local weights and gradients, and the first optimizer param
    group are compared with a non-pipelined deep copy of the model.

    Must be called in a process where the distributed backend is already
    initialized, because it reads ``dist.get_rank()``.

    Args:
        test_config: Supplied by ``@parameterize``. Keys read here:
            ``pp_size``, ``num_microbatches``, ``num_model_chunk``, ``batch_size``.
    """
    # distributed context
    rank = dist.get_rank()
    pp_size = test_config["pp_size"]
    pg_mesh = ProcessGroupMesh(pp_size)
    num_microbatch = test_config["num_microbatches"]
    num_model_chunk = test_config["num_model_chunk"]

    stage_manager = PipelineStageManager(
        pg_mesh,
        pipeline_axis=0,
        enable_interleave=True,
        num_model_chunks=num_model_chunk,
    )

    # Cost model handed to the schedule generator; ``max_mem`` is not passed.
    h, a, s = 4096, 32, 1024
    mem_f = 34 * h + 5 * a * s
    mem_w = -32 * h
    mem_b = -mem_w - mem_f
    unit_costs = {"f_cost": 1, "b_cost": 1, "w_cost": 1, "c_cost": 1}
    memory_deltas = {"f_mem": mem_f, "b_mem": mem_b, "w_mem": mem_w}
    graph = PipelineGraph(n_stage=pp_size, n_micro=num_microbatch, **unit_costs, **memory_deltas)

    zbv_schedule = graph.get_v_schedule()

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
        overlap_p2p=False,
    )

    def criterion(x, *args, **kwargs):
        """Mean of squares; extra arguments are ignored."""
        squared = x * x
        return squared.mean()

    # model and input
    batch_size = test_config["batch_size"]
    num_layers = 8
    assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
    in_dim = out_dim = 16
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]

    input_base = [sample.clone() for sample in data_iter]
    model_base = deepcopy(model)

    # rank -> indices of the two base-model layers it owns; ranks outside the
    # table take the last pair and an nn.Sequential container.
    checked_pairs = {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
    first_idx, second_idx = checked_pairs.get(rank, (3, 4))
    if rank in (0, 1, 2):
        local_chunk = torch.nn.ModuleList().to(rank)
    else:
        local_chunk = torch.nn.Sequential().to(rank)
    for idx, sub_model in enumerate(model.layers):
        if idx in (first_idx, second_idx):
            local_chunk.append(sub_model)

    # optimizers: plain SGD for the reference, wrapped SGD for the local chunk
    optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )

    torch.cuda.synchronize()
    result = scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=optimizer_pp,
        return_loss=True,
        return_outputs=True,
    )

    optimizer_pp.step()

    # reference: forward, backward and step without pipelining
    output_base = model_base(input_base[0])
    loss_base = criterion(output_base)
    loss_base.backward()
    optimizer_base.step()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # loss & outputs: compared on rank 0 only (per the original note, only
    # chunk 1 of stage 0 holds them)
    if rank == 0:
        assert_close(result["loss"], loss_base)
        assert_close(result["outputs"], output_base)

    # weights and gradients of the two locally owned layers
    owned = checked_pairs.get(rank)
    if owned is not None:
        for local_idx, base_idx in enumerate(owned):
            local_weight = local_chunk[local_idx].weight
            base_weight = model_base.layers[base_idx].weight
            assert_close(local_weight, base_weight)
            assert_close(local_weight.grad, base_weight.grad)

    # optimizer state: compare the first param group entry by entry
    base_group = optimizer_base.state_dict()["param_groups"][0]
    pp_group = optimizer_pp.state_dict()["param_groups"][0]
    for (key_base, val_base), (key_pp, val_pp) in zip(base_group.items(), pp_group.items()):
        if key_base != key_pp:
            continue
        if key_base == "params":
            # Known limitation flagged in the original: the base optimizer
            # lists params [0..7] while the pipeline one lists [0, 1], so only
            # the first two entries are compared.
            assert val_base[:2] == val_pp
        else:
            assert val_base == val_pp
# TODO: 4) support the Hybrid plugin, building on 3)
def run_with_hybridplugin(test_config):
    """Placeholder for the Hybrid-plugin variant of this test; does nothing yet.

    Args:
        test_config: Test configuration; currently ignored.
    """
    return None
# TODO: 5) support the MoEHybrid plugin, building on 3)
@parameterize(
    "test_config",
    [
        {
            "batch_size": 4,
            "tp_size": 1,
            "pp_size": 4,
            "num_microbatches": 4,
            "zero_stage": 1,
            "precision": "bf16",
            "num_model_chunk": 4,
        },
    ],
)
def run_with_moehybridplugin(test_config):
    """Train ``transformers_bert`` for one step via the hybrid-plugin helpers and check weights.

    For every selected model in the ``transformers_bert`` sub-registry, this
    builds the original and sharded model/optimizer pair, runs one
    forward/backward, steps both optimizers, and compares two encoder output
    dense layers between the unwrapped models.

    The default torch dtype is switched to bfloat16 while the models are built
    and run, and is restored afterwards.

    Args:
        test_config: Supplied by ``@parameterize``. It is copied before
            ``use_lazy_init``, ``pp_size`` and ``initial_scale`` are overridden,
            so the caller's dict is left untouched.
    """
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    # Work on a copy: the dict belongs to the ``parameterize`` decorator and
    # would otherwise keep these overrides across invocations.
    test_config = {
        **test_config,
        "use_lazy_init": False,
        "pp_size": 1,  # Do NOT test Pipeline Parallel
        "initial_scale": 2**16,  # avoid overflow
    }
    model_list = [
        "transformers_bert",
    ]
    # Same tolerance for every precision (the former bf16/else branches were identical).
    atol, rtol = 5e-4, 5e-4
    weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]

    clear_layout_converter()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    try:
        for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
            if name not in model_list:
                continue
            (
                org_model,
                org_optimizer,
                sharded_model,
                sharded_optimizer,
                criterion,
                booster,
            ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)

            # Only the side effects (populated gradients) are needed; the
            # returned losses and outputs are not checked here.
            run_forward_backward_with_hybrid_plugin(
                org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
            )

            stage_manager = booster.plugin.stage_manager
            tp_group = booster.plugin.tp_group

            bert = unwrap_model(org_model, "BertModel", "bert")
            sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

            org_optimizer.step()
            sharded_optimizer.step()

            # check weights
            if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
                check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
            # TODO: also check optimizer states once a distributed state checker is available
    finally:
        # Do not leak the bfloat16 default into whatever runs next in this process.
        torch.set_default_dtype(previous_dtype)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print("Bert Model Zoo Test Passed")
# TODO: 6) support booster with the Hybrid plugin, building on 4)
# TODO: 7) support booster with the MoEHybrid plugin, building on 4)

def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then run both schedule tests.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` forwarded to ``colossalai.launch``.
    """
    disable_existing_loggers()
    launch_kwargs = dict(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    colossalai.launch(**launch_kwargs)
    # Manual schedule first, then the PipelineGraph-generated one.
    for test_case in (run_fwd_bwd_iter_input, run_fwd_bwd_vschedule_with_optim):
        test_case()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    """Spawn four worker processes, each running ``run_dist``."""
    spawn(run_dist, nprocs=4)


# Script entry point: run the distributed test directly, without pytest.
if __name__ == "__main__":
    test_pp()