diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b4b40020f..3568a5dda 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,8 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -1092,8 +1093,10 @@ class HybridParallelPlugin(PipelinePluginBase): self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style == "interleaved" or pp_style == "zbv" + ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -1103,7 +1106,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", + enable_interleave=(pp_style == "interleaved") or (pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) @@ -1125,6 +1128,31 @@ class HybridParallelPlugin(PipelinePluginBase): microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif pp_style == "zbv": + h, a, s = 4096, 32, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + zbv_schedule = PipelineGraph( + n_stage=self.pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + self.schedule = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 41a886a90..da3039a6f 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -353,7 +353,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # bwd chunk1 is left V; else: - # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}") ################ # chunk = 1 && is_last_stage # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; @@ -409,7 +408,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss.add_(loss.detach()) if outputs is not None: outputs.append(tree_map(detach, output_obj)) - # print(f"accum_loss {accum_loss}; outputs {len(outputs)}; model_chunk_id {model_chunk_id}") return loss else: return output_obj @@ -537,11 +535,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): Returns: Nothing. """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id) + input_obj = micro_batch else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: @@ -619,8 +618,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) - # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n") - # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) @@ -643,7 +640,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj=output_obj, output_obj_grad=output_tensor_grad, ) - # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}") # Step3: send bwd if model_chunk_id == 0: @@ -748,9 +744,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): """ # prepare batch self.load_batch(data_iter) - print( - f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}" - ) # prepare accum loss & output accum_loss = None @@ -762,12 +755,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules - for it in range(len(self.schedules)): - scheduled_node = self.schedules[it] - - print( - f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};" - ) + schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + for it in range(len(schedule)): + scheduled_node = schedule[it] if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 4adc38619..029968231 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,8 @@ from .albert import * from .bert import * from .blip2 import * from .bloom import * -from .chatglm2 import * + +# from .chatglm2 import * from .command import * from .deepseek import * from .falcon import * diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 8c869ae52..b2c988a8b 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -10,10 +10,11 @@ from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper +from colossalai.logging import disable_existing_loggers from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn class MlpModel(nn.Module): @@ -38,19 +39,31 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: # 1) Test manual v_schedule with multiple microbatch -def run_fwd_bwd_iter_input( - rank: int, - world_size: int, - port: int, -): +@parameterize( + "test_config", + [ + { + "batch_size": 4, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 4, + }, + ], +) +def run_fwd_bwd_iter_input(test_config): # init dist - colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") rank = dist.get_rank() - pp_size = world_size + pp_size = test_config["pp_size"] pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = 4 + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] # stage_manager - stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size) + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) # schedule list zbv_schedule = [ @@ -373,7 +386,7 @@ def run_fwd_bwd_iter_input( ] scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ? + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? stage_manager=stage_manager, num_model_chunks=pp_size, num_microbatch=num_microbatch, @@ -419,20 +432,26 @@ def run_fwd_bwd_iter_input( for idx, sub_model in enumerate(model.layers): if idx == 3 or idx == 4: local_chunk.append(sub_model) + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) + print( f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" ) torch.cuda.synchronize() - scheduler.run_forward_backward( + result = scheduler.forward_backward_step( model_chunk=local_chunk, data_iter=iter(data_iter), criterion=criterion, - optimizer=None, - return_loss=None, - return_outputs=None, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, ) + optimizer_pp.step() + ########################## # Fwd bwd for base ########################## @@ -440,6 +459,7 @@ def run_fwd_bwd_iter_input( output_base = model_base(input_base[0]) loss_base = criterion(output_base) loss_base.backward() + optimizer_base.step() print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") ########################## @@ -475,21 +495,28 @@ def run_fwd_bwd_iter_input( assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) -# 2) Test v_schedule generated by graph with multiple microbatch -def run_fwd_bwd_with_vschedule( - rank: int, - world_size: int, - port: int, - num_microbatch: int, - batch_size: int, - num_model_chunk: int, -): +# 2) add optimizer base 1) +@parameterize( + "test_config", + [ + { + "batch_size": 4, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 4, + }, + ], +) +def run_fwd_bwd_vschedule_with_optim(test_config): # init dist - colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") rank = dist.get_rank() - pp_size = world_size + pp_size = test_config["pp_size"] pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = num_microbatch + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] # stage_manager stage_manager = PipelineStageManager( pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk @@ -500,149 +527,7 @@ def run_fwd_bwd_with_vschedule( mem_w = -32 * h mem_b = -mem_w - mem_f graph = PipelineGraph( - n_stage=world_size, - n_micro=num_microbatch, - f_cost=6, - b_cost=6, - w_cost=6, - c_cost=6, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, - # max_mem=mem_f * (p * 2 + m_offset), - ) - - zbv_schedule = graph.get_v_schedule() - - scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ? - stage_manager=stage_manager, - num_model_chunks=num_model_chunk, - num_microbatch=num_microbatch, - overlap_p2p=False, - ) - - def criterion(x, *args, **kwargs): - return (x * x).mean() - - # init model and input - batch_size = batch_size - num_layers = 8 - assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 8 - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") - model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] - model_base = deepcopy(model) - - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.Sequential().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - local_chunk.append(sub_model) - print( - f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" - ) - - torch.cuda.synchronize() - scheduler.run_forward_backward( - model_chunk=local_chunk, - data_iter=iter(data_iter), - criterion=criterion, - optimizer=None, - return_loss=None, - return_outputs=None, - ) - - ########################## - # Fwd bwd for base - ########################## - # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) - loss_base.backward() - print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - -# 3) add optimizer base 2) -def run_fwd_bwd_vschedule_with_optim( - rank: int, - world_size: int, - port: int, - num_microbatch: int, - batch_size: int, - num_model_chunk: int, -): - # init dist - colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") - rank = dist.get_rank() - pp_size = world_size - pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = num_microbatch - # stage_manager - stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk - ) - - h, a, s = 4096, 32, 1024 - mem_f = 34 * h + 5 * a * s - mem_w = -32 * h - mem_b = -mem_w - mem_f - graph = PipelineGraph( - n_stage=world_size, + n_stage=pp_size, n_micro=num_microbatch, f_cost=1, b_cost=1, @@ -657,7 +542,7 @@ def run_fwd_bwd_vschedule_with_optim( zbv_schedule = graph.get_v_schedule() scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ? + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? stage_manager=stage_manager, num_model_chunks=num_model_chunk, num_microbatch=num_microbatch, @@ -669,7 +554,7 @@ def run_fwd_bwd_vschedule_with_optim( return (x * x).mean() # init model and input - batch_size = batch_size + batch_size = test_config["batch_size"] num_layers = 8 assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" in_dim = out_dim = 16 @@ -793,8 +678,27 @@ def run_fwd_bwd_vschedule_with_optim( assert val_base[:2] == val_pp -# 4) support Hybrid base 3) -def run_with_hybrid( +# TODO:4) support Hybrid base 3) +@parameterize( + "test_config", + [ + { + "batch_size": 4, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 4, + }, + ], +) +def run_with_hybridplugin(test_config): + pass + + +# TODO:5) support MoEHybrid base 3) +def run_with_moehybridplugin( rank: int, world_size: int, port: int, @@ -805,35 +709,26 @@ def run_with_hybrid( pass -# 5) support MoE base 3) +# TODO:6) support booster & Hybrid base 4) -# 6) support booster & Hybrid base 4) +# TODO:7) support booster & MoEHybrid base 4) -# 6) support booster & MoE base 4) + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_fwd_bwd_iter_input() + run_fwd_bwd_vschedule_with_optim() @pytest.mark.dist -@pytest.mark.parametrize("num_microbatch", [4]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("num_model_chunk", [4]) @rerun_if_address_is_in_use() -def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): - # spawn( - # run_fwd_bwd_with_vschedule, - # nprocs=4, - # num_microbatch=num_microbatch, - # batch_size=batch_size, - # num_model_chunk=num_model_chunk, - # ) - +def test_pp(): spawn( - run_fwd_bwd_vschedule_with_optim, + run_dist, nprocs=4, - num_microbatch=num_microbatch, - batch_size=batch_size, - num_model_chunk=num_model_chunk, ) if __name__ == "__main__": - test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4) + test_pp()