mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opt, sam, t5, vit, whisper
* fix

parent f1a3a326c4
commit b2483c8e31
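The whole diff applies one pattern: `ShardConfig` carries an `fp8_communication` flag, the shardformer layers and communication primitives (`gather_forward_split_backward`, `split_forward_gather_backward`, `all_to_all_comm`) accept it as a keyword, and every call site in the model forwards and policies threads the flag through. Below is a minimal sketch of that threading, assuming nothing beyond the names visible in the diff; `maybe_fp8_all_gather` is a hypothetical stand-in for the real ColossalAI collectives, and its fp8 branch is only a placeholder.

```python
# Hypothetical sketch of the flag-threading pattern in this PR; not the
# library implementation. Only the field names mirror the diff.
from dataclasses import dataclass

import torch


@dataclass
class ShardConfig:  # mirrors a few of the fields referenced in the diff
    enable_sequence_parallelism: bool = True
    sequence_parallelism_mode: str = "split_gather"
    fp8_communication: bool = False


def maybe_fp8_all_gather(t: torch.Tensor, *, fp8_communication: bool = False) -> torch.Tensor:
    """Stand-in for gather_forward_split_backward / all_to_all_comm."""
    if fp8_communication:
        # Placeholder: the real ops quantize to fp8 before the collective
        # and dequantize afterwards; that logic is not reproduced here.
        pass
    return t  # the real op would also run the collective across the process group


def forward(hidden_states: torch.Tensor, shard_config: ShardConfig) -> torch.Tensor:
    # Every call site in the diff forwards the flag the same way:
    return maybe_fp8_all_gather(
        hidden_states, fp8_communication=shard_config.fp8_communication
    )


if __name__ == "__main__":
    out = forward(torch.randn(2, 4, 8), ShardConfig(fp8_communication=True))
    print(out.shape)
```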
@@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
         gather_output: bool = True,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
+        self.fp8_communication = fp8_communication

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
         if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
             return output
         else:
             return output_parallel
@@ -572,6 +572,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         **kwargs,
     ):
         # create weight and bias
@@ -602,6 +603,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
+            fp8_communication=fp8_communication,
         )
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
@@ -627,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
     ):
         super().__init__()
         # Keep input parameters
@@ -638,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
         self.n_fused = n_fused
         self.process_group = process_group
         self.async_communication = async_communication
+        self.fp8_communication = fp8_communication

         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -767,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
         else:
             output = output_parallel

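Each of the layer changes above follows the same two-step recipe: store the constructor flag on the module, then forward it to the collective inside `forward()`. The sketch below condenses that recipe; the class and the stubbed `gather_forward_split_backward` are illustrative placeholders, not ColossalAI's own implementation.

```python
# Illustrative recipe only; names other than fp8_communication / gather_output
# are placeholders, not ColossalAI's API.
import torch
import torch.nn as nn


def gather_forward_split_backward(x, dim, process_group=None, fp8_communication=False):
    # Stub standing in for the real collective; in a single process it is a no-op.
    return x


class ToyParallelLayer(nn.Module):
    def __init__(self, gather_output: bool = True, fp8_communication: bool = False):
        super().__init__()
        self.gather_output = gather_output
        self.fp8_communication = fp8_communication  # step 1: keep the flag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.gather_output:
            # step 2: forward the flag to the communication primitive
            x = gather_forward_split_backward(
                x, dim=-1, process_group=None, fp8_communication=self.fp8_communication
            )
        return x


print(ToyParallelLayer(fp8_communication=True)(torch.randn(2, 8)).shape)
```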
@@ -187,11 +187,17 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
                 if encoder_hidden_states is not None:
                     encoder_hidden_states = split_forward_gather_backward(
-                        encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                        encoder_hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        fp8_communication=shard_config.fp8_communication,
                     )

         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@@ -242,7 +248,10 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if output_hidden_states:
@@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
-            embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            embedding_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
-                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                encoder_hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )

         encoder_outputs = self.encoder(
@@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):

         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         sequence_output = gather_forward_split_backward(
-            sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            sequence_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )

         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -221,7 +221,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -264,7 +267,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if stage_manager.is_last_stage():
@@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):

         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
@@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                     inputs_embeds,
                     dim=0,
                     process_group=sp_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif sp_mode == "all_to_all":
                 inputs_embeds = split_forward_gather_backward(
@@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                     dim=0,
                     process_group=sp_group,
                     grad_scale=1 / sp_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
@@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif sp_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                     dim=0,
                     process_group=sp_group,
                     grad_scale=sp_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if not return_dict:
@@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
         key_layer = key_layer.reshape(sq, bs, -1)
         value_layer = value_layer.reshape(sq, bs, -1)

-        query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
-        key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
-        value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
+        query_layer = all_to_all_comm(
+            query_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )
+        key_layer = all_to_all_comm(
+            key_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )
+        value_layer = all_to_all_comm(
+            value_layer,
+            sp_group,
+            gather_dim=0,
+            fp8_communication=shard_config.fp8_communication,
+        )

         query_layer = query_layer.view(
             sq * sp_size,
@@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s

         context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
         if sp_mode == "all_to_all":
-            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
+            context_layer = all_to_all_comm(
+                context_layer,
+                sp_group,
+                gather_dim=2,
+                scatter_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )

         # =================
         # Output. [sq, b, h]
@@ -140,6 +140,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -147,6 +148,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         # decoder layers
@@ -211,6 +213,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -218,6 +221,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         # add hidden states from the last decoder layer
@@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds

         # decoder layers
@@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -221,6 +221,7 @@ class GPT2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         # Going through held blocks.
@@ -276,6 +277,7 @@ class GPT2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if stage_manager.is_last_stage():
@@ -185,6 +185,7 @@ class GPTJPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         # Going through held blocks.
@@ -236,6 +237,7 @@ class GPTJPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if stage_manager.is_last_stage():
@@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )

         hidden_states = self.ln_f(hidden_states)
@@ -143,9 +143,13 @@ class LlamaPipelineForwards:
         # Support SP + PP
         if stage_manager.is_first_stage():
             if sp_mode in ["ring", "split_gather"]:
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )

         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -210,9 +214,13 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if sp_mode == "ring" or sp_mode == "split_gather":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+                )

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         # decoder layers
@@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         value_states = self.v_proj(hidden_states)
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         attn_output = attn_output.transpose(1, 2).contiguous()
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         next_decoder_cache = None

         if sp_mode in ["ring", "split_gather"]:
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )

         for decoder_layer in self.layers:
             if output_hidden_states:
@@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states:
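The remainder of the diff updates the shardformer policies: wherever a `SubModuleReplacementDescription` swaps a submodule for `Linear1D_Col`, `Linear1D_Row`, `FusedLinear1D_Col`, an embedding class, or a vocab-parallel LM head, its `kwargs` now also carry `fp8_communication=self.shard_config.fp8_communication` (for embeddings, only when tensor parallelism is enabled). A minimal sketch of that kwargs pattern follows, with placeholder dataclasses standing in for the ColossalAI policy types.

```python
# Placeholder types; only the kwargs pattern mirrors the diff.
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class SubModuleReplacementDescription:  # simplified stand-in
    suffix: str
    target_module: str
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ShardConfig:  # simplified stand-in
    enable_tensor_parallelism: bool = True
    make_vocab_size_divisible_by: int = 64
    fp8_communication: bool = False


def build_descriptions(shard_config: ShardConfig):
    # Linear replacements always receive the flag.
    linear = SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module="Linear1D_Col",
        kwargs={"fp8_communication": shard_config.fp8_communication},
    )
    # Embedding replacements receive it only under tensor parallelism,
    # matching the conditional kwargs blocks in the policy diffs.
    embedding = SubModuleReplacementDescription(
        suffix="embed_tokens",
        target_module="embedding_cls",
        kwargs=(
            {
                "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
                "fp8_communication": shard_config.fp8_communication,
            }
            if shard_config.enable_tensor_parallelism
            else {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
        ),
    )
    return [linear, embedding]


for desc in build_descriptions(ShardConfig(fp8_communication=True)):
    print(desc.suffix, desc.kwargs)
```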
@ -98,6 +98,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -106,6 +107,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -114,6 +116,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -123,7 +126,10 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
|
@ -136,12 +142,16 @@ class BertPolicy(Policy):
|
|||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
|
@ -180,6 +190,13 @@ class BertPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs=(
|
||||
{
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {}
|
||||
),
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
|
@ -249,6 +266,7 @@ class BertPolicy(Policy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
policy=base_policy,
|
||||
|
|
|
@ -72,20 +72,30 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.FusedLinear1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.projection",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -114,14 +124,23 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
|
@ -130,6 +149,9 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
|
@ -138,14 +160,23 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.query",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.key",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.dropout",
|
||||
|
@ -154,6 +185,9 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dropout",
|
||||
|
@ -162,10 +196,16 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="intermediate_query.dense",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dropout",
|
||||
|
@ -185,26 +225,44 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -225,7 +283,14 @@ class BlipPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
|
@ -241,6 +306,7 @@ class BlipPolicy(Policy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
|
|
@ -76,12 +76,19 @@ class BloomPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
|
@ -90,12 +97,19 @@ class BloomPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -115,7 +129,14 @@ class BloomPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
|
@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||
kwargs=dict(
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
|
@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="score",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=BloomForSequenceClassification,
|
||||
|
@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="classifier",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
|
|
|
@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
|
|||
"seq_parallel_mode": sp_mode,
|
||||
"seq_parallel_dim": 0,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"seq_parallel_dim": 0,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.core_attention.attention_dropout",
|
||||
|
@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="embedding.word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
|
|
|
@ -126,37 +126,37 @@ class CommandPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -166,7 +166,14 @@ class CommandPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=CohereModel,
|
||||
|
@ -303,6 +310,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
|
|||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
|
@ -105,7 +105,14 @@ class FalconPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
|
|
|
@ -162,7 +162,14 @@ class GPT2Policy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=GPT2Model,
|
||||
|
@ -332,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
kwargs={
|
||||
"gather_output": False,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
@ -402,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_in",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_out",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
|
@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=GPTJModel,
|
||||
|
@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
|
|||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
@ -166,7 +166,14 @@ class LlamaPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
|
@ -308,6 +315,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
|
@ -88,30 +88,51 @@ class MistralPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -121,7 +142,14 @@ class MistralPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MistralModel,
|
||||
|
@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
kwargs=dict(
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
|
|||
MistralForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -102,18 +102,30 @@ class OPTPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="out_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -123,7 +135,14 @@ class OPTPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OPTDecoder,
|
||||
|
@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
kwargs=dict(
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
|
|
|
@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Qwen2Model,
|
||||
|
@ -317,7 +324,11 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||
new_item = {
|
||||
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
|
||||
)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
)
|
||||
|
@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
|||
Qwen2ForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -43,19 +43,29 @@ class SamPolicy(Policy):
        target_module=col_nn.FusedLinear1D_Col,
        kwargs={
            "n_fused": 3,
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="attn.proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="mlp.lin1",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="mlp.lin2",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
],
)

@ -68,58 +78,100 @@ class SamPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_token_to_image.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_token_to_image.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_token_to_image.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_token_to_image.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="mlp.lin1",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="mlp.lin2",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_image_to_token.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_image_to_token.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_image_to_token.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="cross_attn_image_to_token.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
],
)

@ -132,18 +184,30 @@ class SamPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="final_attn_token_to_image.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="final_attn_token_to_image.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="final_attn_token_to_image.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="final_attn_token_to_image.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
],
)
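The SAM policy repeats the same three-line kwargs block for every projection it shards. If that repetition ever becomes a maintenance burden, a small helper could generate the replacement list instead; a hypothetical sketch (Replacement and with_fp8 are illustrative stand-ins, not ColossalAI APIs):

from dataclasses import dataclass, field
from typing import Any, Iterable, List, Tuple

@dataclass
class Replacement:  # stand-in for SubModuleReplacementDescription
    suffix: str
    target_module: Any
    kwargs: dict = field(default_factory=dict)

def with_fp8(pairs: Iterable[Tuple[str, Any]], fp8_communication: bool) -> List[Replacement]:
    # Build one replacement per (suffix, target_module) pair, all sharing the same flag.
    return [
        Replacement(suffix=s, target_module=m, kwargs={"fp8_communication": fp8_communication})
        for s, m in pairs
    ]

# Strings stand in for the Linear1D_Col / Linear1D_Row classes.
replacements = with_fp8(
    [("self_attn.q_proj", "Linear1D_Col"), ("self_attn.out_proj", "Linear1D_Row")],
    fp8_communication=True,
)
print(replacements[0].kwargs)  # {'fp8_communication': True}
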
@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
    SubModuleReplacementDescription(
        suffix="q",
        target_module=Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="k",
        target_module=Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="v",
        target_module=Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="o",
        target_module=Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="relative_attention_bias",
        target_module=Embedding1D,
        kwargs=dict(gather_output=False),
        kwargs=dict(
            gather_output=False,
            fp8_communication=self.shard_config.fp8_communication,
        ),
        ignore_if_not_exist=True,
    ),
],

@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
    SubModuleReplacementDescription(
        suffix="wi_0 ",
        target_module=Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="wi_1",
        target_module=Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
        suffix="wo",
        target_module=Linear1D_Col,
        kwargs=dict(
            gather_output=True,
            fp8_communication=self.shard_config.fp8_communication,
        ),
    ),
    SubModuleReplacementDescription(
        suffix="dropout",

@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
    SubModuleReplacementDescription(
        suffix="wi",
        target_module=Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="wo",
        target_module=Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="dropout",

@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
description=SubModuleReplacementDescription(
    suffix="embed_tokens",
    target_module=embedding_cls,
    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
    kwargs=(
        {
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        }
        if self.shard_config.enable_tensor_parallelism
        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
    ),
),
policy=policy,
target_key=T5Stack,
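The embedding replacements keep passing make_vocab_size_divisible_by alongside the new flag. That argument pads the vocabulary so the embedding rows split evenly across tensor-parallel ranks; the padding is plain ceiling-to-a-multiple arithmetic. A generic sketch of that arithmetic (the exact padding logic lives in the ColossalAI embedding layers and may differ in detail):

def pad_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> int:
    # Round up to the next multiple so each rank gets an equal slice of rows.
    m = make_vocab_size_divisible_by
    return ((vocab_size + m - 1) // m) * m

print(pad_vocab_size(32100))  # a 32100-token vocabulary pads to 32128
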
@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
    suffix="shared",
    target_module=embedding_cls,
    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
    kwargs=(
        {
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        }
        if self.shard_config.enable_tensor_parallelism
        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
    ),
),
policy=policy,
target_key=T5Model,

@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
    suffix="shared",
    target_module=embedding_cls,
    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
    kwargs=(
        {
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        }
        if self.shard_config.enable_tensor_parallelism
        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
    ),
),
policy=policy,
target_key=T5ForConditionalGeneration,

@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
        kwargs={
            "gather_output": True,
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    policy=policy,

@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
    suffix="shared",
    target_module=embedding_cls,
    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
    kwargs=(
        {
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        }
        if self.shard_config.enable_tensor_parallelism
        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
    ),
),
policy=policy,
target_key=T5EncoderModel,
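Everything in these policies only forwards a boolean; the actual saving comes from the sharded layers casting their collective-communication payloads to FP8 and back. For intuition, here is a local round-trip through the e4m3 format (a conceptual sketch only, not ColossalAI's communication hook; it assumes a PyTorch build that ships torch.float8_e4m3fn):

import torch

def fp8_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Scale into e4m3's representable range (max magnitude ~448), cast down, cast back.
    scale = t.abs().max().clamp(min=1e-12) / 448.0
    t_fp8 = (t / scale).to(torch.float8_e4m3fn)  # roughly what would travel over the wire
    return t_fp8.to(t.dtype) * scale             # what the receiver reconstructs

x = torch.randn(4, 8)
print((x - fp8_roundtrip(x)).abs().max())  # small quantization error, not zero
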
@ -70,14 +70,23 @@ class ViTPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="attention.attention.query",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="attention.attention.key",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="attention.attention.value",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="attention.attention.dropout",

@ -86,6 +95,9 @@ class ViTPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="attention.output.dense",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="attention.output.dropout",

@ -96,11 +108,15 @@ class ViTPolicy(Policy):
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "skip_bias_add": self.enable_bias_gelu_fused,
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="output.dense",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="output.dropout",

@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ModulePolicyDescription(
    sub_module_replacement=[
        SubModuleReplacementDescription(
            suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
            suffix="classifier",
            target_module=Linear1D_Col,
            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
        )
    ]
)
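The classifier head above is column-parallel with gather_output=True: each rank computes a slice of the class logits and the slices are concatenated back into the full class dimension, and that gather is the collective the fp8_communication flag compresses. A shard-free illustration of the math in plain PyTorch (no distributed setup, the two ranks are simulated locally):

import torch

torch.manual_seed(0)
x = torch.randn(2, 16)      # (batch, hidden)
w = torch.randn(10, 16)     # (num_classes, hidden): the full classifier weight
w0, w1 = w.chunk(2, dim=0)  # two column-parallel shards of the output dimension

partial0 = x @ w0.t()       # logits computed on "rank 0"
partial1 = x @ w1.t()       # logits computed on "rank 1"
gathered = torch.cat([partial0, partial1], dim=-1)  # what gather_output=True produces

assert torch.allclose(gathered, x @ w.t(), atol=1e-5)
print(gathered.shape)  # torch.Size([2, 10])
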
@ -91,26 +91,44 @@ class WhisperPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="fc1",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="fc2",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
],
)

@ -128,42 +146,72 @@ class WhisperPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="encoder_attn.q_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="encoder_attn.k_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="encoder_attn.v_proj",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="encoder_attn.out_proj",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="fc1",
        target_module=col_nn.Linear1D_Col,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    SubModuleReplacementDescription(
        suffix="fc2",
        target_module=col_nn.Linear1D_Row,
        kwargs={
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
],
)

@ -174,7 +222,14 @@ class WhisperPolicy(Policy):
    SubModuleReplacementDescription(
        suffix="embed_tokens",
        target_module=embedding_cls,
        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
        kwargs=(
            {
                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                "fp8_communication": self.shard_config.fp8_communication,
            }
            if self.shard_config.enable_tensor_parallelism
            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
        ),
    ),
],
policy=policy,

@ -303,6 +358,7 @@ class WhisperPolicy(Policy):
        kwargs={
            "gather_output": True,
            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
            "fp8_communication": self.shard_config.fp8_communication,
        },
    ),
    policy=base_policy,
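Taken together, every policy touched by this commit reads the flag from self.shard_config, so it only has to be set once on the plugin side. A rough end-to-end sketch, assuming the HybridParallelPlugin from this PR series exposes an fp8_communication argument that it forwards into the shard config (check the plugin signature in your ColossalAI version before relying on it):

# Rough usage sketch -- the fp8_communication plugin argument is an assumption based on
# this PR's title, not a verified signature.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # distributed env set up beforehand, e.g. via torchrun

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    fp8_communication=True,  # assumed flag; the policies above thread it into each replacement
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)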