[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2, gptj, bert, falcon, blip2

* mistral, opt, sam, t5, vit, whisper

* fix

* fix

* fix
Wang Binluo 2024-08-12 18:17:05 +08:00 committed by GitHub
parent f1a3a326c4
commit b2483c8e31
27 changed files with 633 additions and 83 deletions
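
For context, the flag threaded through every hunk below is intended to be switched on from the booster plugin. A minimal usage sketch follows; it assumes the `HybridParallelPlugin` constructor exposes an `fp8_communication` flag that is forwarded into the `ShardConfig` consumed by the shardformer policies (the other arguments follow the existing plugin API).

```python
# Illustrative sketch only: enabling fp8 communication through the hybrid
# parallel plugin. `fp8_communication=True` is the flag this PR threads into
# ShardConfig; the remaining arguments follow the existing plugin API.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                                # tensor parallel degree
    pp_size=1,                                # pipeline parallel degree
    precision="bf16",
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
    fp8_communication=True,                   # quantize TP/SP collectives to fp8
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler are then wrapped as usual:
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )
```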


@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
gather_output: bool = True,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
fp8_communication: bool = False,
*args,
**kwargs,
):
@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
self.fp8_communication = fp8_communication
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output
else:
return output_parallel
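
The hunk above only threads the flag through. Conceptually, when `fp8_communication` is set, the all-gather payload is cast to fp8 with a per-tensor scale before the collective and dequantized afterwards, roughly as in the sketch below. This is an illustration of the idea, not the library's implementation; the helper name, scaling scheme, and uint8 view are assumptions.

```python
# Sketch of an fp8-quantized all-gather (illustrative, not ColossalAI's code).
import torch
import torch.distributed as dist

def all_gather_fp8(x: torch.Tensor, group=None, dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size(group)

    # Per-tensor scale so values fit the fp8 e4m3 range (max magnitude ~448).
    scale = (x.abs().max().clamp(min=1e-12) / 448.0).reshape(1)
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)

    # NCCL has no fp8 type, so ship the bytes as uint8 and gather scales alongside.
    payload = x_fp8.view(torch.uint8)
    bufs = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(bufs, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Dequantize each shard and concatenate along the gathered dimension.
    chunks = [b.view(torch.float8_e4m3fn).to(x.dtype) * s for b, s in zip(bufs, scales)]
    return torch.cat(chunks, dim=dim)
```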


@ -572,6 +572,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
**kwargs,
):
# create weight and bias
@ -602,6 +603,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
fp8_communication=fp8_communication,
)
# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)


@ -627,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
@ -638,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -767,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel


@ -187,11 +187,17 @@ class BertPipelineForwards:
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@ -242,7 +248,10 @@ class BertPipelineForwards:
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if output_hidden_states:
@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
embedding_output = split_forward_gather_backward(
embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
embedding_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
encoder_outputs = self.encoder(
@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
sequence_output = gather_forward_split_backward(
sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
sequence_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
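
For readers unfamiliar with the helpers being extended here: `split_forward_gather_backward` shards the activation along the given dimension in the forward pass and all-gathers the gradient in the backward pass (`gather_forward_split_backward` is its mirror image). A conceptual sketch follows; the class name and details are illustrative rather than the library's code, and the fp8 flag marks where the gradient collective would be quantized.

```python
# Conceptual sketch of split_forward_gather_backward (illustrative only).
import torch
import torch.distributed as dist

class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, group, fp8_communication=False):
        ctx.dim, ctx.group, ctx.fp8 = dim, group, fp8_communication
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        # keep only this rank's slice of the split dimension
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        parts = [torch.empty_like(grad_output) for _ in range(world_size)]
        # with ctx.fp8 set, the payload would be quantized to fp8 around this
        # collective (see the all-gather sketch earlier)
        dist.all_gather(parts, grad_output.contiguous(), group=ctx.group)
        return torch.cat(parts, dim=ctx.dim), None, None, None

def split_forward_gather_backward(x, dim, process_group, fp8_communication=False):
    return _SplitForwardGatherBackward.apply(x, dim, process_group, fp8_communication)
```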


@ -221,7 +221,10 @@ class BloomPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
start_idx, end_idx = stage_index[0], stage_index[1]
@ -264,7 +267,10 @@ class BloomPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():
@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)


@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
inputs_embeds,
dim=0,
process_group=sp_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=1 / sp_size,
fp8_communication=shard_config.fp8_communication,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)
if not return_dict:
@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
key_layer = key_layer.reshape(sq, bs, -1)
value_layer = value_layer.reshape(sq, bs, -1)
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
query_layer = all_to_all_comm(
query_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
key_layer = all_to_all_comm(
key_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
value_layer = all_to_all_comm(
value_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
query_layer = query_layer.view(
sq * sp_size,
@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
if sp_mode == "all_to_all":
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
context_layer = all_to_all_comm(
context_layer,
sp_group,
gather_dim=2,
scatter_dim=0,
fp8_communication=shard_config.fp8_communication,
)
# =================
# Output. [sq, b, h]
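
The `all_to_all_comm` calls above redistribute the q/k/v projections for "all_to_all" sequence parallelism: activations sharded along the sequence dimension are exchanged so they become sharded along the head dimension, and back again for the attention output. A conceptual, forward-only sketch is below; the function body is illustrative, and with `fp8_communication` the exchanged payload would be quantized to fp8 around the collective.

```python
# Conceptual sketch of all_to_all_comm (illustrative, forward path only).
import torch
import torch.distributed as dist

def all_to_all_comm(x, group=None, scatter_dim=2, gather_dim=0):
    world_size = dist.get_world_size(group)
    # split along scatter_dim: one chunk is sent to each rank
    send = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    recv = [torch.empty_like(send[0]) for _ in range(world_size)]
    # with fp8_communication, each chunk would be cast to fp8 before this call
    dist.all_to_all(recv, send, group=group)
    # reassemble along gather_dim: the local shard of that dim is now complete
    return torch.cat(recv, dim=gather_dim)
```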


@ -140,6 +140,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -147,6 +148,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
@ -211,6 +213,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -218,6 +221,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:


@ -221,6 +221,7 @@ class GPT2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Going through held blocks.
@ -276,6 +277,7 @@ class GPT2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():


@ -185,6 +185,7 @@ class GPTJPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Going through held blocks.
@ -236,6 +237,7 @@ class GPTJPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():
@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)


@ -143,9 +143,13 @@ class LlamaPipelineForwards:
# Support SP + PP
if stage_manager.is_first_stage():
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
@ -210,9 +214,13 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:


@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
if output_hidden_states:
@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
next_decoder_cache = None
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
for decoder_layer in self.layers:
if output_hidden_states:
@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:


@ -98,6 +98,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -106,6 +107,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -114,6 +116,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -123,7 +126,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -136,12 +142,16 @@ class BertPolicy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@ -180,6 +190,13 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs=(
{
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {}
),
)
],
policy=policy,
@ -249,6 +266,7 @@ class BertPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,
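
The policy hunks in the rest of this diff all follow the pattern above: read `self.shard_config.fp8_communication` and forward it into the replacement module's kwargs. A minimal sketch of the config field they rely on is below (assumed for illustration; the `shard_config.py` hunk itself is not shown in this excerpt).

```python
# Assumed shape of the ShardConfig field consumed by the policies (sketch only).
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShardConfig:
    enable_tensor_parallelism: bool = True
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: Optional[str] = None
    make_vocab_size_divisible_by: int = 64
    parallel_output: bool = True
    fp8_communication: bool = False  # new flag: run TP/SP collectives in fp8
```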


@ -72,20 +72,30 @@ class BlipPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.projection",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -114,14 +124,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
@ -130,6 +149,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -138,14 +160,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="crossattention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",
@ -154,6 +185,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="crossattention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",
@ -162,10 +196,16 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate_query.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output_query.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output_query.dropout",
@ -185,26 +225,44 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -225,7 +283,14 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -241,6 +306,7 @@ class BlipPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
],


@ -76,12 +76,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
@ -90,12 +97,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -115,7 +129,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
fp8_communication=self.shard_config.fp8_communication,
),
),
policy=policy,
@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
policy=policy,
target_key=BloomForSequenceClassification,
@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="dropout",


@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,


@ -126,37 +126,37 @@ class CommandPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@ -166,7 +166,14 @@ class CommandPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=CohereModel,
@ -303,6 +310,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],


@ -105,7 +105,14 @@ class FalconPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,


@ -162,7 +162,14 @@ class GPT2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=GPT2Model,
@ -332,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
kwargs={
"gather_output": False,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@ -402,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]


@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=GPTJModel,
@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]


@ -166,7 +166,14 @@ class LlamaPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=LlamaModel,
@ -308,6 +315,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],


@ -88,30 +88,51 @@ class MistralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -121,7 +142,14 @@ class MistralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=MistralModel,
@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]
@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
kwargs=dict(
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
)
]
)
@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
MistralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)


@ -102,18 +102,30 @@ class OPTPolicy(Policy):
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -123,7 +135,14 @@ class OPTPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=OPTDecoder,
@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
fp8_communication=self.shard_config.fp8_communication,
),
),
policy=policy,


@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=Qwen2Model,
@ -317,7 +324,11 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
Qwen2ForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)


@ -43,19 +43,29 @@ class SamPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -68,58 +78,100 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -132,18 +184,30 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)


@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
kwargs=dict(
gather_output=False,
fp8_communication=self.shard_config.fp8_communication,
),
ignore_if_not_exist=True,
),
],
@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
),
),
SubModuleReplacementDescription(
suffix="dropout",
@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="dropout",
@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Stack,
@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Model,
@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5ForConditionalGeneration,
@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5EncoderModel,


@ -70,14 +70,23 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
@ -86,6 +95,9 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -96,11 +108,15 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)


@ -91,26 +91,44 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -128,42 +146,72 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -174,7 +222,14 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -303,6 +358,7 @@ class WhisperPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,