From fe209164f1cb96de0c8a834736466bbd27fc5ce9 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Tue, 27 Aug 2024 10:29:39 +0000
Subject: [PATCH] [feat] add apply v_schedule graph; p & p.grad assert err exist;

---
 colossalai/pipeline/schedule/v_schedule.py    |  12 +-
 .../test_schedule/test_zerobubble_pp.py       | 149 +++++++++++++++++-
 2 files changed, 150 insertions(+), 11 deletions(-)

diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py
index f1ea3f61e..b5c255e50 100644
--- a/colossalai/pipeline/schedule/v_schedule.py
+++ b/colossalai/pipeline/schedule/v_schedule.py
@@ -12,8 +12,8 @@ class ScheduledNode:
     chunk: int
     stage: int
     minibatch: int
-    # start_time: int
-    # completion_time: int
+    start_time: int = 0
+    completion_time: int = 0
     rollback: bool = False


@@ -460,9 +460,9 @@ class PipelineGraph(object):
                     )
                 )
             assert len(rollback_comm) == 0
-            for node in local_order_with_rollback[rank]:
-                print(f"Rank {rank} Node info {node}")
-                print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
-            print()
+            # for node in local_order_with_rollback[rank]:
+            #     print(f"Rank {rank} Node info {node}")
+            #     print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ")
+            # print()

         return local_order_with_rollback
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 40aedfa47..605524a88 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -9,7 +9,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.schedule.v_schedule import ScheduledNode
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -389,10 +389,9 @@ def test_run_fwd_bwd_iter_input(
     in_dim = out_dim = 8
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
     data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]

-    [t.clone() for t in data_iter]
+    input_base = [t.clone() for t in data_iter]
     model_base = deepcopy(model)

     if rank == 0:
@@ -437,7 +436,143 @@ def test_run_fwd_bwd_iter_input(
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(data_iter[0])
+    output_base = model_base(input_base[0])
+    loss_base = criterion(output_base)
+    loss_base.backward()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
+
+    ##########################
+    # assert weight
+    ##########################
+    if rank == 0:
+        # layer 0
+        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
+        # layer 7
+        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
+    if rank == 1:
+        # layer 1
+        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
+        # layer 6
+        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
+    if rank == 2:
+        # layer 2
+        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
+        # layer 5
+        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
+    if rank == 3:
+        # layer 3
+        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
+        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
+        # layer 4
+        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
+        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
+
+
+# T
+def test_run_fwd_bwd_with_vschedule(
+    rank: int,
+    world_size: int,
+    port: int,
+):
+    # init dist
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+    rank = dist.get_rank()
+    pp_size = world_size
+    pg_mesh = ProcessGroupMesh(pp_size)
+    num_microbatch = 4
+    # stage_manager
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
+
+    h, a, s = 4096, 32, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=world_size,
+        n_micro=num_microbatch,
+        f_cost=6,
+        b_cost=6,
+        w_cost=6,
+        c_cost=6,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+        # max_mem=mem_f * (p * 2 + m_offset),
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    scheduler = ZeroBubbleVPipeScheduler(
+        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
+        stage_manager=stage_manager,
+        num_model_chunks=pp_size,
+        num_microbatch=num_microbatch,
+        overlap_p2p=False,
+    )
+
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()
+
+    # init model and input
+    batch_size = 4
+    num_layers = 8
+    in_dim = out_dim = 8
+    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
+    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
+    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+
+    input_base = [t.clone() for t in data_iter]
+    model_base = deepcopy(model)
+
+    if rank == 0:
+        # layer 0 & 7 to chunk 0 on rank0
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 0 or idx == 7:
+                local_chunk.append(sub_model)
+    elif rank == 1:
+        # layer 1 & 6 to chunk 1 on rank1
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 1 or idx == 6:
+                local_chunk.append(sub_model)
+    elif rank == 2:
+        # layer 2 & 5 to chunk 2 on rank2
+        local_chunk = torch.nn.ModuleList().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 2 or idx == 5:
+                local_chunk.append(sub_model)
+    else:
+        # layer 3 & 4 to chunk 3 on rank3
+        local_chunk = torch.nn.Sequential().to(rank)
+        for idx, sub_model in enumerate(model.layers):
+            if idx == 3 or idx == 4:
+                local_chunk.append(sub_model)
+    print(
+        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
+    )
+
+    torch.cuda.synchronize()
+    scheduler.run_forward_backward(
+        model_chunk=local_chunk,
+        data_iter=iter(data_iter),
+        criterion=criterion,
+        optimizer=None,
+        return_loss=None,
+        return_outputs=None,
+    )
+
+    ##########################
+    # Fwd bwd for base
+    ##########################
+    # fwd & bwd
+    output_base = model_base(input_base[0])
+    loss_base = criterion(output_base)
+    loss_base.backward()
+    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -481,8 +616,12 @@ def test_run_fwd_bwd_iter_input(
 # @pytest.mark.parametrize("num_model_chunk", [2])
 @rerun_if_address_is_in_use()
 def test_pp():
+    # spawn(
+    #     test_run_fwd_bwd_iter_input,
+    #     nprocs=4,
+    # )
     spawn(
-        test_run_fwd_bwd_iter_input,
+        test_run_fwd_bwd_with_vschedule,
         nprocs=4,
     )
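
The debug dump that this patch comments out in v_schedule.py can still be reproduced outside the library when the schedule needs inspecting. The standalone sketch below is not part of the patch; it only assumes the PipelineGraph / get_v_schedule API exactly as the new test exercises it, and copies the cost and memory constants from test_run_fwd_bwd_with_vschedule. It builds the same graph and prints one line per rank, including the start_time / completion_time values that ScheduledNode now carries as real dataclass fields. It also illustrates the inline "whole schedule or local schedule" hint: get_v_schedule() returns one list per rank, and the test hands the scheduler only its local slice, zbv_schedule[rank].

# Standalone sketch, not part of the patch: rebuild the graph from
# test_run_fwd_bwd_with_vschedule and dump the per-rank schedule, mirroring
# the debug loop commented out in v_schedule.py. Assumes only the
# PipelineGraph / get_v_schedule API used above; constants copied from the test.
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

world_size = 4       # pipeline stages, same as the spawned test
num_microbatch = 4

h, a, s = 4096, 32, 1024
mem_f = 34 * h + 5 * a * s   # forward-chunk activation footprint (test heuristic)
mem_w = -32 * h              # memory released by a weight-grad (W) pass
mem_b = -mem_w - mem_f       # memory released by an input-grad (B) pass

graph = PipelineGraph(
    n_stage=world_size,
    n_micro=num_microbatch,
    f_cost=6,
    b_cost=6,
    w_cost=6,
    c_cost=6,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)

# get_v_schedule() returns one list of ScheduledNode per rank.
zbv_schedule = graph.get_v_schedule()
for rank, nodes in enumerate(zbv_schedule):
    # Same summary the removed prints produced (type-minibatch-rollback),
    # plus the start/completion times now exposed on ScheduledNode.
    line = ", ".join(
        f"{node.type}-{node.minibatch}-{int(node.rollback)}"
        f"[{node.start_time},{node.completion_time}]"
        for node in nodes
    )
    print(f"Rank {rank}: {line}")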