[feat] add and apply v_schedule graph; p & p.grad assertion errors still exist;

pull/6034/head
duanjunwen 2024-08-27 10:29:39 +00:00
parent 8b37323f16
commit fe209164f1
2 changed files with 150 additions and 11 deletions

colossalai/pipeline/schedule/v_schedule.py

@@ -12,8 +12,8 @@ class ScheduledNode:
     chunk: int
     stage: int
     minibatch: int
-    # start_time: int
-    # completion_time: int
+    start_time: int = 0
+    completion_time: int = 0
     rollback: bool = False
@@ -460,9 +460,9 @@ class PipelineGraph(object):
             )
         )
         assert len(rollback_comm) == 0
-        for node in local_order_with_rollback[rank]:
-            print(f"Rank {rank} Node info {node}")
-            print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
-        print()
+        # for node in local_order_with_rollback[rank]:
+        #     print(f"Rank {rank} Node info {node}")
+        #     print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
+        # print()
         return local_order_with_rollback
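
With the defaults introduced above, a ScheduledNode can be constructed before any timing is known, and a later graph pass can fill in start_time and completion_time. A minimal sketch of that usage, assuming type, chunk, stage, and minibatch are the only required fields (type is referenced by the commented-out debug prints above but not shown in this hunk):

    from colossalai.pipeline.schedule.v_schedule import ScheduledNode

    # start_time / completion_time now default to 0, so a node can exist
    # before the scheduler assigns real timings.
    node = ScheduledNode(type="F", chunk=0, stage=0, minibatch=0)
    assert node.start_time == 0 and node.completion_time == 0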

tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@@ -9,7 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.schedule.v_schedule import ScheduledNode
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -389,10 +389,9 @@ def test_run_fwd_bwd_iter_input(
     in_dim = out_dim = 8
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
     data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
-    [t.clone() for t in data_iter]
+    input_base = [t.clone() for t in data_iter]
     model_base = deepcopy(model)
     if rank == 0:
@@ -437,7 +436,143 @@ def test_run_fwd_bwd_iter_input(
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(data_iter[0])
+    output_base = model_base(input_base[0])
     loss_base = criterion(output_base)
     loss_base.backward()
     print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    ##########################
+    # assert weight
+    ##########################
+    if rank == 0:
+        # layer 0
+        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
+    if rank == 2:
+        # layer 2
+        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
+    if rank == 3:
+        # layer 3
+        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
+
+
+# T
+def test_run_fwd_bwd_with_vschedule(
+    rank: int,
+    world_size: int,
+    port: int,
+):
+    # init dist
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    rank = dist.get_rank()
+    pp_size = world_size
+    pg_mesh = ProcessGroupMesh(pp_size)
+    num_microbatch = 4
+    # stage_manager
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
+
+    h, a, s = 4096, 32, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=world_size,
+        n_micro=num_microbatch,
+        f_cost=6,
+        b_cost=6,
+        w_cost=6,
+        c_cost=6,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+        # max_mem=mem_f * (p * 2 + m_offset),
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    scheduler = ZeroBubbleVPipeScheduler(
+        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
+        stage_manager=stage_manager,
+        num_model_chunks=pp_size,
+        num_microbatch=num_microbatch,
+        overlap_p2p=False,
+    )
+
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()
+
+    # init model and input
+    batch_size = 4
+    num_layers = 8
+    in_dim = out_dim = 8
+    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
+    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+    input_base = [t.clone() for t in data_iter]
+    model_base = deepcopy(model)
+
+    if rank == 0:
+        # layer 0 & 7 to chunk 0 on rank0
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 0 or idx == 7:
+                local_chunk.append(sub_model)
+    elif rank == 1:
+        # layer 1 & 6 to chunk 1 on rank1
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 1 or idx == 6:
+                local_chunk.append(sub_model)
+    elif rank == 2:
+        # layer 2 & 5 to chunk 2 on rank2
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 2 or idx == 5:
+                local_chunk.append(sub_model)
+    else:
+        # layer 3 & 4 to chunk 3 on rank3
+        local_chunk = torch.nn.Sequential().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 3 or idx == 4:
+                local_chunk.append(sub_model)
+    print(
+        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    )
+
+    torch.cuda.synchronize()
+    scheduler.run_forward_backward(
+        model_chunk=local_chunk,
+        data_iter=iter(data_iter),
+        criterion=criterion,
+        optimizer=None,
+        return_loss=None,
+        return_outputs=None,
+    )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base[0])
+    loss_base = criterion(output_base)
+    loss_base.backward()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -481,8 +616,12 @@ def test_run_fwd_bwd_iter_input(
 # @pytest.mark.parametrize("num_model_chunk", [2])
 @rerun_if_address_is_in_use()
 def test_pp():
+    # spawn(
+    #     test_run_fwd_bwd_iter_input,
+    #     nprocs=4,
+    # )
     spawn(
-        test_run_fwd_bwd_iter_input,
+        test_run_fwd_bwd_with_vschedule,
         nprocs=4,
     )
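
A note on the arguments passed to PipelineGraph in the new test: f_cost, b_cost, w_cost, and c_cost are relative time costs, while mem_f, mem_b, and mem_w are per-pass activation-memory deltas. The coefficients (34 * h + 5 * a * s for forward, -32 * h for the weight-gradient pass) are taken verbatim from the diff and should be read as the test's heuristic for a transformer-like layer, not a general formula; the salient property of the three values is that a full F + B + W round nets to zero memory, which the definition of mem_b guarantees by construction:

    # Constants as in the test (hidden size, attention heads, sequence length).
    h, a, s = 4096, 32, 1024
    mem_f = 34 * h + 5 * a * s  # memory gained by one forward (F) pass
    mem_w = -32 * h             # memory released by the weight-gradient (W) pass
    mem_b = -mem_w - mem_f      # memory released by the activation-gradient (B) pass

    # One full F + B + W round is memory-neutral.
    assert mem_f + mem_b + mem_w == 0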