diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index e310e9bf3..bab118b85 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -64,10 +64,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_tensor_metadata = True - self.send_grad_metadata = True - self.tensor_metadata_recv = None - self.grad_metadata_recv = None + self.send_tensor_metadata = [True, True] + self.send_grad_metadata = [True, True] + # meta cache buffer + self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] + self.grad_metadata_recv = [None, None] # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) @@ -235,10 +236,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward( - prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv + prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) # return input_tensor, wait_handles return wait_handles @@ -259,10 +260,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward( - next_rank, metadata_recv=self.tensor_metadata_recv + next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor) # return input_tensor, wait_handles return wait_handles @@ -297,10 +298,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank, metadata_recv=self.grad_metadata_recv + next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) # return output_tensor_grad, wait_handles return wait_handles @@ -322,10 +323,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank=prev_rank, metadata_recv=self.grad_metadata_recv + next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) # return output_tensor_grad, wait_handles return wait_handles @@ -359,9 +360,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata + output_object=output_tensor, + next_rank=next_rank, + send_metadata=self.send_tensor_metadata[model_chunk_id], ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles else: @@ -380,9 +383,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_tensor, prev_rank, send_metadata=self.send_tensor_metadata + output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id] ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -415,9 +418,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles # bwd chunk1 is left V; @@ -437,9 +440,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def forward_step( @@ -662,6 +665,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) + # print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};") # Step3: # 3-1:detach output; detach output for send fwd; @@ -886,6 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] + # print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};") if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a51a1df9f..d1ad84604 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -191,7 +191,6 @@ class LlamaPipelineForwards: num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx - for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2b3a30bad..528638f41 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -9,6 +9,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -104,7 +105,7 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaModel, ) - + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -191,6 +192,84 @@ class LlamaPolicy(Policy): ], ) + # not enable tp, replace layer to LinearWithGradAccum + else: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // tp_size, + "self_attn.num_heads": num_q_heads, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -416,6 +495,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): policy = super().module_policy() use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -434,6 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) + # enable tp, replace layer to LinearWithGradAccum + else: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + # to be confirmed if self.pipeline_stage_manager: # set None as default diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index ad5d35161..4976f0c37 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -163,8 +163,6 @@ def main(): enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -179,8 +177,6 @@ def main(): enable_flash_attention=args.xformers, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": if use_empty_init: @@ -192,7 +188,6 @@ def main(): ), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -214,7 +209,6 @@ def main(): cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -225,7 +219,6 @@ def main(): ), cpu_offload=CPUOffload(offload_params=True), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": if args.pp_style == "zbv": diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 71ff11059..b8ef09bea 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -758,11 +758,13 @@ def run_with_hybridplugin(test_config): @parameterize( "config", [ - # (0, 1, 4, 1, 1), + # # Pass + (1, 2, 1, 1, 2), + # TODO: adapt mixtral with no TP Linear # (1, 2, 2, 1, 1), - (1, 1, 2, 2, 1), + # (0, 1, 4, 1, 1), + # (1, 1, 2, 2, 1), # (1, 2, 1, 2, 1), - # (1, 2, 1, 1, 2), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -910,7 +912,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} config {test_config} test passed") clear_layout_converter() @@ -921,11 +922,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @parameterize( "config", [ - (1, 2, 2, 1), # Pass - # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture; - # (0, 4, 1, 1), - # (1, 2, 1, 2), - # (1, 1, 2, 2), + # # Pass + (1, 2, 2, 1), + (1, 2, 1, 2), + (1, 1, 2, 2), + # TODO: acc err in pp4 + (1, 4, 1, 1), ], ) def run_with_booster_hybridplugin(config: Tuple[int, ...]):