diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index b7b284213..8c319aceb 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -432,7 +432,6 @@ def _communicate( overlap_p2p=overlap_p2p, send_first=send_first if send_first != None else True, ) - if metadata_recv is not None: assert isinstance(metadata_recv, P2PMetadata) tree_spec = metadata_recv.tree_spec diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index e310e9bf3..89c868aae 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -64,10 +64,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_tensor_metadata = True - self.send_grad_metadata = True - self.tensor_metadata_recv = None - self.grad_metadata_recv = None + + # check send_tensor_metadata, send_grad_metadata + # pp4 as sample, we should follow this meta strategy + # send_tensor_meta(fwd) send_grad_meta(bwd) + # chunk0 | chunk1 chunk0 | chunk 1 + # stage 0 T | F F | T + # stage 1 T | T T | T + # stage 2 T | T T | T + # stage 3 F | T F | T + if stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata = [True, False] + self.send_grad_metadata = [False, True] + elif stage_manager.is_last_stage(ignore_chunk=True): + self.send_tensor_metadata = [False, True] + self.send_grad_metadata = [True, False] + else: + self.send_tensor_metadata = [True, True] + self.send_grad_metadata = [True, True] + + # meta cache buffer + self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta] + self.grad_metadata_recv = [None, None] # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) @@ -96,10 +114,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.output_tensors_grad_dw = [[], []] # buffer for communication - self.send_forward_buffer = [[], []] - self.recv_forward_buffer = [[], []] - self.send_backward_buffer = [[], []] - self.recv_backward_buffer = [[], []] + self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_forward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] + self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_backward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] # y buffer for local send fwd self.local_send_forward_buffer = [] @@ -225,7 +249,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -235,12 +258,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward( - prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv + prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + 
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles else: @@ -259,12 +281,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward( - next_rank, metadata_recv=self.tensor_metadata_recv + next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) - # return input_tensor, wait_handles + if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: + self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -287,7 +308,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already get dy from local_send_backward_buffer in schedule b ################ if self.stage_manager.is_last_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -297,12 +317,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank, metadata_recv=self.grad_metadata_recv + next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - # return output_tensor_grad, wait_handles + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles else: @@ -312,7 +331,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): - # return None, [] return [] ################ @@ -322,12 +340,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward( - next_rank=prev_rank, metadata_recv=self.grad_metadata_recv + next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id] ) - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) - # return output_tensor_grad, wait_handles + if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: + self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -349,6 +366,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; hold y on local_send_forward_buffer ################ if self.stage_manager.is_last_stage(ignore_chunk=True): + 
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -359,9 +377,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata + output_object=output_tensor, + next_rank=next_rank, + send_metadata=self.send_tensor_metadata[model_chunk_id], ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles else: @@ -370,6 +390,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part ################ if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -380,9 +401,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward( - output_tensor, prev_rank, send_metadata=self.send_tensor_metadata + output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id] ) - self.send_tensor_metadata = not self.enable_metadata_cache + self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -405,6 +426,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; cause u are the first chunk in first stage; bwd end ################ if self.stage_manager.is_first_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -415,9 +437,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): prev_rank = self.stage_manager.get_prev_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles # bwd chunk1 is left V; @@ -427,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; ################ if self.stage_manager.is_last_stage(ignore_chunk=True): + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return [] ################ @@ -437,9 +460,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): next_rank = self.stage_manager.get_next_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward( - input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata + input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id] ) - self.send_grad_metadata = not self.enable_metadata_cache + self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache return send_handles def forward_step( @@ -519,8 +542,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_obj_grad_ = [] # For chunk 0 stage 0, use micro_batch as input_obj_; and we 
don't have to cal microbatch dx. - # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): - # return None # For loss backward; output_obj is loss; output_obj_grad should be None if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): @@ -633,9 +654,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 0: # is first stage; get input from microbatch if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = None + input_obj = None # (tensor, wait_handle) else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] else: # is last stage; recv from local if self.stage_manager.is_last_stage(ignore_chunk=True): @@ -643,7 +667,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) - + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] # Here, let input_obj.requires_grad_() # if input_obj is not None: if not isinstance(input_obj, torch.Tensor): @@ -689,10 +715,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Do not release_tensor_data loss, release_tensor_data other output_obj; if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): self.output_tensors[model_chunk_id].append(output_obj) - # self.output_tensors_dw[model_chunk_id].append(output_obj) else: self.output_tensors[model_chunk_id].append(output_obj) - # self.output_tensors_dw[model_chunk_id].append(output_obj) # add output to send_fwd_buffer if model_chunk_id == 0: # chunk 0 @@ -732,6 +756,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): @@ -739,25 +766,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) - # # save output_tensor_grad for dw - # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # # we save loss here - # self.output_tensors_grad_dw[model_chunk_id].append(output_obj) - # else: - # # we save output_tensor_grad here - # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # the_output_obj_grad = [] - # if isinstance(output_obj, dict): - # for (k, v) in output_obj.items(): - # the_output_obj_grad.append(v.requires_grad) - # else: - # the_output_obj_grad.append(output_obj.requires_grad) - input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, @@ -800,20 +816,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): Returns: Nothing. 
""" - - # get y & dy from buffer - # output_obj = self.output_tensors_dw[model_chunk_id].pop(0) - # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) WeightGradStore.pop(chunk=model_chunk_id) - # self.backward_w_step( - # model_chunk=model_chunk, - # model_chunk_id=model_chunk_id, - # optimizer=optimizer, - # output_obj=output_obj, - # output_obj_grad=output_obj_grad, - # ) - def run_forward_only( self, model_chunk: Union[ModuleList, Module], @@ -890,7 +894,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) - self.wait_handles.append(wait_handle) + # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle + if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: + self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, @@ -914,10 +920,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) + # wait here to ensure all communication is done for h in self.wait_handles: for hh in h: hh.wait() - # return loss & output if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 3202ebf25..019a6b140 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -6,6 +6,7 @@ import torch.distributed import torch.distributed as dist import torch.nn.functional as F from einops import rearrange +from packaging import version from colossalai.kernel.kernel_loader import ( FlashAttentionDaoLoader, @@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function): max_seqlen_q = max_seqlen_kv = max_seqlen cu_seqlens_half = cu_seqlens // 2 max_seqlen_half = max_seqlen // 2 - misc_kwargs = { - "window_size": (-1, -1), "alibi_slopes": None, "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale, "dropout_p": dropout_p, @@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function): "softcap": 0.0, "return_softmax": False, } + import flash_attn + + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + misc_kwargs["window_size_left"] = -1 + misc_kwargs["window_size_right"] = -1 + else: + misc_kwargs["window_size"] = (-1, -1) if ( RingAttention.HALF_INDICES is not None @@ -707,26 +713,39 @@ class RingAttention(torch.autograd.Function): # Helper to pass args to FA def _forward(q, k, v, causal): - ( - _, - _, - _, - _, - out, - softmax_lse, - _, - rng_state, - ) = _flash_attn_forward( - q, - k, - v, - cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, - cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, - max_seqlen_q if q.shape[0] == t else max_seqlen_half, - max_seqlen_kv if k.shape[0] == t else max_seqlen_half, - causal=causal, - **misc_kwargs, - ) + if version.parse(flash_attn.__version__) > version.parse("2.6.3"): + (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) + else: + ( + _, + _, + _, + _, + out, + softmax_lse, + _, + rng_state, + ) = _flash_attn_forward( + q, + k, + v, + cu_seqlens_q if q.shape[0] == t else cu_seqlens_half, + 
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half, + max_seqlen_q if q.shape[0] == t else max_seqlen_half, + max_seqlen_kv if k.shape[0] == t else max_seqlen_half, + causal=causal, + **misc_kwargs, + ) return out, softmax_lse, rng_state def _kv_comm(i): diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a51a1df9f..d1ad84604 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -191,7 +191,6 @@ class LlamaPipelineForwards: num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx - for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 3687cfb99..a88db87bc 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -381,7 +381,6 @@ class MixtralPipelineForwards: output_router_logits, use_cache, ) - hidden_states = layer_outputs[0] if use_cache: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 09673d396..63cd49280 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -75,6 +75,8 @@ class BertPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 @@ -97,6 +99,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -105,6 +108,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -113,6 +117,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -125,6 +130,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -138,6 +144,7 @@ class BertPolicy(Policy): "seq_parallel_mode": sp_mode, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -146,6 +153,97 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) + if self.enable_bias_gelu_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_intermediate_forward(), + }, + policy=policy, + target_key=BertIntermediate, + ) + + elif use_zbv: + policy[BertLayer] = ModulePolicyDescription( + sub_module_replacement=[ + 
SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.LinearWithGradAccum, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2b3a30bad..e8f9471f9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -9,6 +9,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, RMSNorm, @@ -104,7 +105,7 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaModel, ) - + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: assert ( num_q_heads % tp_size == 0 @@ -191,6 +192,76 @@ class LlamaPolicy(Policy): ], ) + # not enable tp, replace layer to LinearWithGradAccum + elif use_zbv: + policy[LlamaDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + 
SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ), + ], + ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -416,6 +487,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): policy = super().module_policy() use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row, if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { @@ -434,6 +506,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ) } policy.update(new_item) + # enable tp, replace layer to LinearWithGradAccum + elif use_zbv: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ] + ) + } + policy.update(new_item) + # to be confirmed if self.pipeline_stage_manager: # set None as default diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 4d16038c1..b4b87df92 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + LinearWithGradAccum, PaddingEmbedding, PaddingLMHead, VocabParallelEmbedding1D, @@ -62,6 +63,8 @@ class MistralPolicy(Policy): if self.tie_weight: embedding_cls = PaddingEmbedding + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -90,6 +93,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -97,6 +101,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -104,6 +109,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -111,6 +117,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -118,6 +125,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), 
SubModuleReplacementDescription( @@ -125,6 +133,7 @@ class MistralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -132,6 +141,68 @@ class MistralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) + elif use_zbv: + policy[MistralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, }, ), ], diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index ece72d929..fab437c01 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -7,9 +7,18 @@ from torch import Tensor from torch.nn import Module from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col -from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D -from colossalai.shardformer.layer.linear import Linear1D_Row +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + LinearWithGradAccum, + PaddingEmbedding, + VocabParallelEmbedding1D, +) + +# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D +# from colossalai.shardformer.layer.linear import Linear1D_Row from colossalai.shardformer.modeling.mixtral import ( EPMixtralSparseMoeBlock, MixtralPipelineForwards, @@ -166,6 +175,51 @@ class MixtralPolicy(Policy): ], ) + elif use_zbv: + policy[MixtralDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=LinearWithGradAccum, + kwargs={ + 
"fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + SubModuleReplacementDescription( + suffix="block_sparse_moe.gate", + target_module=LinearWithGradAccum, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": use_zbv, + }, + ), + ], + ) if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -351,6 +405,22 @@ class MixtralForCausalLMPolicy(MixtralPolicy): ) } policy.update(new_item) + elif use_zbv: + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=LinearWithGradAccum, + kwargs=dict( + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), + ) + ], + ) + } + policy.update(new_item) if self.pipeline_stage_manager: # set None as default diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index ad5d35161..4976f0c37 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -163,8 +163,6 @@ def main(): enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -179,8 +177,6 @@ def main(): enable_flash_attention=args.xformers, use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, - use_fp8=args.use_fp8, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": if use_empty_init: @@ -192,7 +188,6 @@ def main(): ), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -214,7 +209,6 @@ def main(): cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -225,7 +219,6 @@ def main(): ), cpu_offload=CPUOffload(offload_params=True), fp8_communication=args.use_fp8_comm, - fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": if args.pp_style == "zbv": diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py index 0334bd81c..dbffd0c2a 100644 --- a/examples/language/mixtral/benchmark.py +++ b/examples/language/mixtral/benchmark.py @@ -122,7 +122,7 @@ def main(): num_ckpt_layers_per_stage=[19, 19, 19, 13], ), "num_layers_per_stage": [19, 20, 20, 21], - # "pp_style": "interleaved", + "pp_style": "interleaved", } if args.custom_ckpt else {} diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 71ff11059..a01b75eee 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -749,24 +749,17 @@ def run_fwd_bwd_vschedule_with_optim(test_config): assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) -# TODO:3) support booster & Hybrid base 2) -def 
run_with_hybridplugin(test_config): - pass - - -# TODO:4) support booster & MoEHybrid base 2) @parameterize( "config", [ - # (0, 1, 4, 1, 1), - # (1, 2, 2, 1, 1), + (1, 2, 1, 1, 2), (1, 1, 2, 2, 1), - # (1, 2, 1, 2, 1), - # (1, 2, 1, 1, 2), + (1, 2, 1, 2, 1), + (1, 2, 2, 1, 1), + (1, 1, 4, 1, 1), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): - test_config = config stage, ep_size, pp_size, tp_size, sp_size = config num_microbatches = pp_size dist.get_world_size() @@ -876,7 +869,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): return_outputs=True, ) # stage 0 chunk 0 - parallel_output = None if ( booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) and rank == dist.get_process_group_ranks(plugin.pp_group)[0] @@ -910,9 +902,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) - print(f"rank {dist.get_rank()} config {test_config} test passed") clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache() @@ -921,11 +911,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @parameterize( "config", [ - (1, 2, 2, 1), # Pass - # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture; - # (0, 4, 1, 1), - # (1, 2, 1, 2), - # (1, 1, 2, 2), + # Pass + (1, 2, 2, 1), + (1, 2, 1, 2), + (1, 1, 2, 2), + (1, 4, 1, 1), ], ) def run_with_booster_hybridplugin(config: Tuple[int, ...]): @@ -1034,7 +1024,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): return_outputs=True, ) # stage 0 chunk 0 - parallel_output = None if ( booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) and rank == dist.get_process_group_ranks(plugin.pp_group)[0] @@ -1068,9 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): p.grad /= dp_size torch_optimizer.step() torch_optimizer.zero_grad() - assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) - print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") + clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache()
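
The scheduler change above alters the contract of the recv buffers: instead of storing bare tensors, `recv_forward_buffer` / `recv_backward_buffer` now hold `(tensor, wait_handles)` pairs, and `schedule_f` / `schedule_b` wait on the handles right before the tensor is consumed, so `run_forward_backward` only has to track SEND handles. Below is a minimal standalone sketch of that pattern, not part of the patch; the names `DummyHandle`, `recv_forward_buffer`, and `pop_with_wait` are illustrative stand-ins, assuming handles expose a blocking `wait()` like torch.distributed async work objects.

```python
# Illustrative sketch only (assumed names; not code from the patch).
from typing import List, Tuple

import torch


class DummyHandle:
    """Stand-in for an async P2P work handle (e.g. returned by dist.irecv)."""

    def wait(self) -> None:
        # A real handle would block until the communication completes.
        pass


# [chunk0, chunk1]; each entry is (tensor, wait_handles), mirroring the patch.
recv_forward_buffer: List[List[Tuple[torch.Tensor, List[DummyHandle]]]] = [[], []]


def pop_with_wait(buffer: List[Tuple[torch.Tensor, List[DummyHandle]]]) -> torch.Tensor:
    """Pop the oldest (tensor, handles) pair and block until the recv has finished."""
    tensor, handles = buffer.pop(0)
    for h in handles:
        h.wait()  # deferred wait: only done when the tensor is actually needed
    return tensor


if __name__ == "__main__":
    # Producer side: enqueue the received tensor together with its handles.
    recv_forward_buffer[0].append((torch.zeros(2, 2), [DummyHandle()]))
    # Consumer side (schedule_f-like): wait, then use the tensor.
    x = pop_with_wait(recv_forward_buffer[0])
    print(x.shape)
```

Keeping the wait next to the consumer (rather than in the main schedule loop) is what lets the main loop append only `SEND_FORWARD` / `SEND_BACKWARD` handles to `self.wait_handles`, as the patch does.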