From b2483c8e31c7291f1d083d7c9c59bd79203c3c3b Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Mon, 12 Aug 2024 18:17:05 +0800
Subject: [PATCH] [fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert,falcon,blip2

* mistral,opt,sam,t5,vit,whisper

* fix

* fix

* fix
---
 colossalai/shardformer/layer/embedding.py   |  6 +-
 colossalai/shardformer/layer/linear.py      |  2 +
 .../shardformer/layer/qkv_fused_linear.py   |  6 +-
 colossalai/shardformer/modeling/bert.py     | 30 ++++++--
 colossalai/shardformer/modeling/bloom.py    | 20 ++++-
 colossalai/shardformer/modeling/chatglm2.py | 37 +++++++++-
 colossalai/shardformer/modeling/command.py  | 30 ++++++--
 colossalai/shardformer/modeling/gpt2.py     |  2 +
 colossalai/shardformer/modeling/gptj.py     |  4 +
 colossalai/shardformer/modeling/llama.py    | 16 +++-
 colossalai/shardformer/modeling/qwen2.py    | 30 ++++++--
 colossalai/shardformer/policies/bert.py     | 22 +++++-
 colossalai/shardformer/policies/blip2.py    | 70 +++++++++++++++++-
 colossalai/shardformer/policies/bloom.py    | 40 ++++++++--
 colossalai/shardformer/policies/chatglm2.py | 16 +++-
 colossalai/shardformer/policies/command.py  | 24 ++++--
 colossalai/shardformer/policies/falcon.py   |  9 ++-
 colossalai/shardformer/policies/gpt2.py     | 11 ++-
 colossalai/shardformer/policies/gptj.py     | 22 +++++-
 colossalai/shardformer/policies/llama.py    | 10 ++-
 colossalai/shardformer/policies/mistral.py  | 39 +++++++++-
 colossalai/shardformer/policies/opt.py      | 22 +++++-
 colossalai/shardformer/policies/qwen2.py    | 33 ++++++---
 colossalai/shardformer/policies/sam.py      | 64 ++++++++++++++++
 colossalai/shardformer/policies/t5.py       | 73 +++++++++++++++++--
 colossalai/shardformer/policies/vit.py      | 20 ++++-
 colossalai/shardformer/policies/whisper.py  | 58 ++++++++++++++-
 27 files changed, 633 insertions(+), 83 deletions(-)

diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index 186063503..18efb0ec5 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
         gather_output: bool = True,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
+        self.fp8_communication = fp8_communication
 
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
         if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
             return output
         else:
             return output_parallel
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 38a6ef1a1..4e5ebef0d 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -572,6 +572,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         **kwargs,
     ):
         # create weight and bias
@@ -602,6 +603,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
+            fp8_communication=fp8_communication,
         )
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 93a7eb231..561867993 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -627,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
     ):
         super().__init__()
         # Keep input parameters
@@ -638,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
         self.n_fused = n_fused
         self.process_group = process_group
         self.async_communication = async_communication
+        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -767,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):
 
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
         else:
             output = output_parallel
 
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 7710b56e7..580f3618c 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -187,11 +187,17 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
                 if encoder_hidden_states is not None:
                     encoder_hidden_states = split_forward_gather_backward(
-                        encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                        encoder_hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        fp8_communication=shard_config.fp8_communication,
                     )
 
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@@ -242,7 +248,10 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         if output_hidden_states:
@@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
-            embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            embedding_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
-                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                encoder_hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         encoder_outputs = self.encoder(
@@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism is done, gather the output tensor in forward and split it in backward
         sequence_output = gather_forward_split_backward(
-            sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            sequence_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 26ffef6c5..f8fd4665f 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -221,7 +221,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -264,7 +267,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         if stage_manager.is_last_stage():
@@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism is done, gather the output tensor in forward and split it in backward
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 34d900d8d..5fd7b6461 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 inputs_embeds,
                 dim=0,
                 process_group=sp_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         elif sp_mode == "all_to_all":
             inputs_embeds = split_forward_gather_backward(
@@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 dim=0,
                 process_group=sp_group,
                 grad_scale=1 / sp_size,
+                fp8_communication=shard_config.fp8_communication,
             )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
@@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 hidden_states,
                 dim=0,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         elif sp_mode == "all_to_all":
             hidden_states = gather_forward_split_backward(
@@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 dim=0,
                 process_group=sp_group,
                 grad_scale=sp_size,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         if not return_dict:
@@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
         key_layer = key_layer.reshape(sq, bs, -1)
         value_layer = value_layer.reshape(sq, bs, -1)
 
-        query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
-        key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
-        value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
+        query_layer = all_to_all_comm(
+            query_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )
+        key_layer = all_to_all_comm(
+            key_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )
+        value_layer = all_to_all_comm(
+            value_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )
 
         query_layer = query_layer.view(
             sq * sp_size,
@@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
         context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
 
         if sp_mode == "all_to_all":
-            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
+            context_layer = all_to_all_comm(
+                context_layer,
+                sp_group,
+                gather_dim=2,
+                scatter_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )
 
         # =================
         # Output. [sq, b, h]
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 5b36fc7db..1ece0c118 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -140,6 +140,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -147,6 +148,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # decoder layers
@@ -211,6 +213,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -218,6 +221,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # add hidden states from the last decoder layer
@@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 0dbf0ca5a..97544e110 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -221,6 +221,7 @@ class GPT2PipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         # Going through held blocks.
@@ -276,6 +277,7 @@ class GPT2PipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         if stage_manager.is_last_stage():
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index facd2fcaf..51b228712 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -185,6 +185,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         # Going through held blocks.
@@ -236,6 +237,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         if stage_manager.is_last_stage():
@@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         hidden_states = self.ln_f(hidden_states)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 693f6584f..99b31745b 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -143,9 +143,13 @@ class LlamaPipelineForwards:
         # Support SP + PP
         if stage_manager.is_first_stage():
             if sp_mode in ["ring", "split_gather"]:
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )
 
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -210,9 +214,13 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if sp_mode == "ring" or sp_mode == "split_gather":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+                )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 538e96c32..a61360d17 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # decoder layers
@@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         elif shard_config.sequence_parallelism_mode == "all_to_all":
             hidden_states = gather_forward_split_backward(
@@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
                 dim=1,
                 process_group=shard_config.sequence_parallel_process_group,
                 grad_scale=shard_config.sequence_parallel_size,
+                fp8_communication=shard_config.fp8_communication,
             )
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         value_states = self.v_proj(hidden_states)
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         attn_output = attn_output.transpose(1, 2).contiguous()
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         next_decoder_cache = None
 
         if sp_mode in ["ring", "split_gather"]:
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         for decoder_layer in self.layers:
             if output_hidden_states:
@@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index b84a372a5..4c33e14bc 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -98,6 +98,7 @@ class BertPolicy(Policy):
                     kwargs={
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -106,6 +107,7 @@ class BertPolicy(Policy):
                     kwargs={
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -114,6 +116,7 @@ class BertPolicy(Policy):
                     kwargs={
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                 ),
                 SubModuleReplacementDescription(
@@ -123,7 +126,10 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
-                    kwargs={"seq_parallel_mode": sp_mode},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
@@ -136,12 +142,16 @@ class BertPolicy(Policy):
                         "seq_parallel_mode": sp_mode,
                         "overlap": overlap,
                         "skip_bias_add": self.enable_bias_gelu_fused,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
                     target_module=col_nn.Linear1D_Row,
-                    kwargs={"seq_parallel_mode": sp_mode},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dropout",
@@ -180,6 +190,13 @@ class BertPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="word_embeddings",
                     target_module=embedding_cls,
+                    kwargs=(
+                        {
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {}
+                    ),
                 )
             ],
             policy=policy,
@@ -249,6 +266,7 @@ class BertPolicy(Policy):
                 kwargs={
                     "gather_output": True,
                     "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    "fp8_communication": self.shard_config.fp8_communication,
                 },
             ),
             policy=base_policy,
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 32d4edadb..da798f6a0 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -72,20 +72,30 @@ class BlipPolicy(Policy):
                     target_module=col_nn.FusedLinear1D_Col,
                     kwargs={
                         "n_fused": 3,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.projection",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc1",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
+                    kwargs={
+                        "skip_bias_add": self.enable_bias_gelu_fused,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc2",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -114,14 +124,23 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.attention.query",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.key",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.value",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.dropout",
@@ -130,6 +149,9 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
@@ -138,14 +160,23 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="crossattention.attention.query",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="crossattention.attention.key",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="crossattention.attention.value",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="crossattention.attention.dropout",
@@ -154,6 +185,9 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="crossattention.output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="crossattention.output.dropout",
@@ -162,10 +196,16 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="intermediate_query.dense",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output_query.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output_query.dropout",
@@ -185,26 +225,44 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="fc1",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="fc2",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -225,7 +283,14 @@ class BlipPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="model.decoder.embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
             ],
             policy=policy,
@@ -241,6 +306,7 @@ class BlipPolicy(Policy):
                     kwargs={
                         "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
             ],
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index d80adb84a..a43ac02d0 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -76,12 +76,19 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attention.query_key_value",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attention.dense",
                     target_module=col_nn.Linear1D_Row,
-                    kwargs={"seq_parallel_mode": sp_mode},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attention.attention_dropout",
@@ -90,12 +97,19 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_h_to_4h",
                     target_module=col_nn.Linear1D_Col,
-                    kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.dense_4h_to_h",
                     target_module=col_nn.Linear1D_Row,
-                    kwargs={"seq_parallel_mode": sp_mode},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -115,7 +129,14 @@ class BloomPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="word_embeddings",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
             ],
             policy=policy,
@@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
                 kwargs=dict(
                     gather_output=not self.shard_config.parallel_output,
                     make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                    fp8_communication=self.shard_config.fp8_communication,
                 ),
             ),
             policy=policy,
@@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="score",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 policy=policy,
                 target_key=BloomForSequenceClassification,
@@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
             self.append_or_create_submodule_replacement(
                 description=[
                     SubModuleReplacementDescription(
-                        suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="classifier",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="dropout",
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 3877bdac3..16c13085a 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
                         "seq_parallel_mode": sp_mode,
                         "seq_parallel_dim": 0,
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attention.dense",
                     target_module=col_nn.Linear1D_Row,
-                    kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
+                    kwargs={
+                        "seq_parallel_mode": sp_mode,
+                        "seq_parallel_dim": 0,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attention.core_attention.attention_dropout",
@@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="embedding.word_embeddings",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
             ],
             policy=policy,
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index a9b915d10..d0903b11c 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -126,37 +126,37 @@ class CommandPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
             ],
         )
@@ -166,7 +166,14 @@ class CommandPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=CohereModel,
@@ -303,6 +310,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
                         kwargs={
                             "gather_output": not self.shard_config.parallel_output,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index e5c167337..e20fb1568 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -105,7 +105,14 @@ class FalconPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="word_embeddings",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
             ],
             policy=policy,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index bb6269737..0a1949d85 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -162,7 +162,14 @@ class GPT2Policy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="wte",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=GPT2Model,
@@ -332,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                         kwargs={
                             "gather_output": False,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
@@ -402,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
                         kwargs={
                             "gather_output": True,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ]
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index c394d911e..6f0c8803c 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
@@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
                         "overlap": overlap,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc_in",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.fc_out",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.attn_dropout",
@@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="wte",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=GPTJModel,
@@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
                         kwargs={
                             "gather_output": True,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ]
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 6f8404219..23d2d3913 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -166,7 +166,14 @@ class LlamaPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=LlamaModel,
@@ -308,6 +315,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                         kwargs={
                             "gather_output": not self.shard_config.parallel_output,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index c5a0277a5..72a5158e5 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -88,30 +88,51 @@ class MistralPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -121,7 +142,14 @@ class MistralPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
            ),
             policy=policy,
             target_key=MistralModel,
@@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
                         kwargs={
                             "gather_output": not self.shard_config.parallel_output,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ]
@@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
                     SubModuleReplacementDescription(
                         suffix="lm_head",
                         target_module=PaddingLMHead,
-                        kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
+                        kwargs=dict(
+                            make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                        ),
                    )
                 ]
             )
@@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
             MistralForSequenceClassification: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                     )
                 ]
             )
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 524d2b8cd..dd64ce652 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -102,18 +102,30 @@ class OPTPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="q_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="k_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="v_proj",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="out_proj",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -123,7 +135,14 @@ class OPTPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=OPTDecoder,
@@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
                 kwargs=dict(
                     gather_output=not self.shard_config.parallel_output,
                     make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                    fp8_communication=self.shard_config.fp8_communication,
                 ),
             ),
             policy=policy,
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 362c14060..1b066200d 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
             ],
         )
@@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=Qwen2Model,
@@ -317,7 +324,11 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
             new_item = {
                 Qwen2ForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
-                        SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
+                        )
                     ],
                     method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
@@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
             Qwen2ForSequenceClassification: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                     )
                 ]
             )
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 53faf8997..674fe5e58 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -43,19 +43,29 @@ class SamPolicy(Policy):
                     target_module=col_nn.FusedLinear1D_Col,
                     kwargs={
                         "n_fused": 3,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.lin1",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.lin2",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -68,58 +78,100 @@ class SamPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_token_to_image.q_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_token_to_image.k_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_token_to_image.v_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_token_to_image.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.lin1",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.lin2",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_image_to_token.q_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_image_to_token.k_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_image_to_token.v_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="cross_attn_image_to_token.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
@@ -132,18 +184,30 @@ class SamPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="final_attn_token_to_image.q_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="final_attn_token_to_image.k_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="final_attn_token_to_image.v_proj",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="final_attn_token_to_image.out_proj",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
             ],
         )
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 0b594678c..84b5d9594 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="q",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="k",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="v",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="o",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="relative_attention_bias",
                     target_module=Embedding1D,
-                    kwargs=dict(gather_output=False),
+                    kwargs=dict(
+                        gather_output=False,
+                        fp8_communication=self.shard_config.fp8_communication,
+                    ),
                     ignore_if_not_exist=True,
                 ),
             ],
@@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="wi_0 ",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="wi_1",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
-                    suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="wo",
+                    target_module=Linear1D_Col,
+                    kwargs=dict(
+                        gather_output=True,
+                        fp8_communication=self.shard_config.fp8_communication,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
@@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="wi",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="wo",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
@@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
            target_key=T5Stack,
@@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
             description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=T5Model,
@@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
             description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=T5ForConditionalGeneration,
@@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                 kwargs={
                     "gather_output": True,
                     "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    "fp8_communication": self.shard_config.fp8_communication,
                 },
             ),
             policy=policy,
@@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
             description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=T5EncoderModel,
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 069ad0c26..07202094f 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -70,14 +70,23 @@ class ViTPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.attention.query",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.key",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.value",
                     target_module=col_nn.Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.attention.dropout",
@@ -86,6 +95,9 @@ class ViTPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="attention.output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attention.output.dropout",
@@ -96,11 +108,15 @@ class ViTPolicy(Policy):
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
                         "skip_bias_add": self.enable_bias_gelu_fused,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dense",
                     target_module=col_nn.Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="output.dropout",
@@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="classifier", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 441e512bb..7a1f146d5 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -91,26 +91,44 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -128,42 +146,72 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -174,7 +222,14 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="embed_tokens", 
target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, @@ -303,6 +358,7 @@ class WhisperPolicy(Policy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ), policy=base_policy,
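
Note on usage: the hunks above only thread `fp8_communication` from `ShardConfig`
into the replacement layers (Linear1D_Col, Linear1D_Row, Embedding1D,
FusedLinear1D_Col, VocabParallelLMHead1D); the flag itself is set by whichever
plugin builds the ShardConfig. Below is a minimal sketch of how this is expected
to be switched on end to end. It assumes the `fp8_communication` argument on
HybridParallelPlugin wired up elsewhere in this PR series, and uses a Qwen2
checkpoint purely for illustration; treat the exact argument set as indicative
rather than authoritative.

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from transformers import AutoModelForCausalLM

    colossalai.launch_from_torch()

    # fp8_communication=True flows into ShardConfig; every policy patched above
    # reads it as self.shard_config.fp8_communication and forwards it to the
    # Linear1D_Col / Linear1D_Row / Embedding1D replacements, which quantize
    # their gather/split collectives to fp8.
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        precision="bf16",
        fp8_communication=True,  # assumed flag from this PR series
    )
    booster = Booster(plugin=plugin)

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    model, *_ = booster.boost(model)

Because the flag defaults to False, every kwarg added in this patch is a no-op
for existing configurations: the layers keep their full-precision communication
paths unless fp8 is requested explicitly.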