mirror of https://github.com/hpcaitech/ColossalAI
[feat] add apply v_schedule graph; p & p.grad assert err exist;
parent
8b37323f16
commit
fe209164f1
|
@ -12,8 +12,8 @@ class ScheduledNode:
|
|||
chunk: int
|
||||
stage: int
|
||||
minibatch: int
|
||||
# start_time: int
|
||||
# completion_time: int
|
||||
start_time: int = 0
|
||||
completion_time: int = 0
|
||||
rollback: bool = False
|
||||
|
||||
|
||||
|
@ -460,9 +460,9 @@ class PipelineGraph(object):
|
|||
)
|
||||
)
|
||||
assert len(rollback_comm) == 0
|
||||
for node in local_order_with_rollback[rank]:
|
||||
print(f"Rank {rank} Node info {node}")
|
||||
print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
|
||||
print()
|
||||
# for node in local_order_with_rollback[rank]:
|
||||
# print(f"Rank {rank} Node info {node}")
|
||||
# print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
|
||||
# print()
|
||||
|
||||
return local_order_with_rollback
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
@ -389,10 +389,9 @@ def test_run_fwd_bwd_iter_input(
|
|||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
|
||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
|
||||
[t.clone() for t in data_iter]
|
||||
input_base = [t.clone() for t in data_iter]
|
||||
model_base = deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -437,7 +436,143 @@ def test_run_fwd_bwd_iter_input(
|
|||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(data_iter[0])
|
||||
output_base = model_base(input_base[0])
|
||||
loss_base = criterion(output_base)
|
||||
loss_base.backward()
|
||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
|
||||
##########################
|
||||
# assert weight
|
||||
##########################
|
||||
if rank == 0:
|
||||
# layer 0
|
||||
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
|
||||
# layer 7
|
||||
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
|
||||
if rank == 1:
|
||||
# layer 1
|
||||
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
|
||||
# layer 6
|
||||
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
|
||||
if rank == 2:
|
||||
# layer 2
|
||||
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
|
||||
# layer 5
|
||||
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
|
||||
if rank == 3:
|
||||
# layer 3
|
||||
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
|
||||
# layer 4
|
||||
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||
|
||||
|
||||
# T
|
||||
def test_run_fwd_bwd_with_vschedule(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
):
|
||||
# init dist
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
rank = dist.get_rank()
|
||||
pp_size = world_size
|
||||
pg_mesh = ProcessGroupMesh(pp_size)
|
||||
num_microbatch = 4
|
||||
# stage_manager
|
||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
|
||||
|
||||
h, a, s = 4096, 32, 1024
|
||||
mem_f = 34 * h + 5 * a * s
|
||||
mem_w = -32 * h
|
||||
mem_b = -mem_w - mem_f
|
||||
graph = PipelineGraph(
|
||||
n_stage=world_size,
|
||||
n_micro=num_microbatch,
|
||||
f_cost=6,
|
||||
b_cost=6,
|
||||
w_cost=6,
|
||||
c_cost=6,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
# max_mem=mem_f * (p * 2 + m_offset),
|
||||
)
|
||||
|
||||
zbv_schedule = graph.get_v_schedule()
|
||||
|
||||
scheduler = ZeroBubbleVPipeScheduler(
|
||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
||||
stage_manager=stage_manager,
|
||||
num_model_chunks=pp_size,
|
||||
num_microbatch=num_microbatch,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
def criterion(x, *args, **kwargs):
|
||||
return (x * x).mean()
|
||||
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
num_layers = 8
|
||||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
|
||||
input_base = [t.clone() for t in data_iter]
|
||||
model_base = deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
# layer 0 & 7 to chunk 0 on rank0
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 0 or idx == 7:
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 1:
|
||||
# layer 1 & 6 to chunk 1 on rank1
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 1 or idx == 6:
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 2:
|
||||
# layer 2 & 5 to chunk 2 on rank2
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 2 or idx == 5:
|
||||
local_chunk.append(sub_model)
|
||||
else:
|
||||
# layer 3 & 4 to chunk 3 on rank3
|
||||
local_chunk = torch.nn.Sequential().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 3 or idx == 4:
|
||||
local_chunk.append(sub_model)
|
||||
print(
|
||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
scheduler.run_forward_backward(
|
||||
model_chunk=local_chunk,
|
||||
data_iter=iter(data_iter),
|
||||
criterion=criterion,
|
||||
optimizer=None,
|
||||
return_loss=None,
|
||||
return_outputs=None,
|
||||
)
|
||||
|
||||
##########################
|
||||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(input_base[0])
|
||||
loss_base = criterion(output_base)
|
||||
loss_base.backward()
|
||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
|
@ -481,8 +616,12 @@ def test_run_fwd_bwd_iter_input(
|
|||
# @pytest.mark.parametrize("num_model_chunk", [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp():
|
||||
# spawn(
|
||||
# test_run_fwd_bwd_iter_input,
|
||||
# nprocs=4,
|
||||
# )
|
||||
spawn(
|
||||
test_run_fwd_bwd_iter_input,
|
||||
test_run_fwd_bwd_with_vschedule,
|
||||
nprocs=4,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue