mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] chatglm support sequence parallel (#4482)
* [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fixpull/4498/head
parent
351351a36e
commit
59e252ecdb
|
@ -74,6 +74,7 @@ class Linear1D_Col(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
overlap: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
|
@ -87,6 +88,7 @@ class Linear1D_Col(ParallelModule):
|
|||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.seq_parallel = seq_parallel
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.overlap = overlap
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
|
@ -190,7 +192,8 @@ class Linear1D_Col(ParallelModule):
|
|||
bias = self.bias if not self.skip_bias_add else None
|
||||
if self.seq_parallel:
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||
self.process_group, True, 1, self.overlap)
|
||||
self.process_group, True,
|
||||
self.seq_parallel_dim, self.overlap)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
|
||||
|
@ -236,6 +239,7 @@ class Linear1D_Row(ParallelModule):
|
|||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
|
@ -254,6 +258,7 @@ class Linear1D_Row(ParallelModule):
|
|||
self.skip_bias_add = skip_bias_add
|
||||
self.process_group = process_group
|
||||
self.seq_parallel = seq_parallel
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
|
@ -390,7 +395,8 @@ class Linear1D_Row(ParallelModule):
|
|||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
if self.seq_parallel:
|
||||
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
|
||||
self.seq_parallel_dim)
|
||||
else:
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
||||
ChatGLMForConditionalGeneration,
|
||||
|
@ -146,6 +148,7 @@ class ChatGLMPipelineForwards:
|
|||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
output_hidden_states = (output_hidden_states
|
||||
|
@ -198,6 +201,11 @@ class ChatGLMPipelineForwards:
|
|||
all_self_attentions = None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
if output_hidden_states:
|
||||
|
@ -214,6 +222,11 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states, kv_cache = layer_ret
|
||||
if use_cache:
|
||||
presents = presents + (kv_cache,)
|
||||
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -233,23 +246,22 @@ class ChatGLMPipelineForwards:
|
|||
return {'hidden_states': hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def chatglm_for_conditional_generation_forward(
|
||||
self: ChatGLMForConditionalGeneration,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_last_logit: Optional[bool] = False,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
return_last_logit: Optional[bool] = False,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None):
|
||||
logger = logging.get_logger(__name__)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
|
@ -266,6 +278,7 @@ class ChatGLMPipelineForwards:
|
|||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
@ -296,3 +309,91 @@ class ChatGLMPipelineForwards:
|
|||
)
|
||||
else:
|
||||
return transformer_outputs
|
||||
|
||||
|
||||
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
if self.pre_seq_len is not None:
|
||||
if past_key_values is None:
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=batch_size,
|
||||
device=input_ids.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
||||
attention_mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
||||
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
||||
|
||||
# Rotary positional embeddings
|
||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||
if position_ids is not None:
|
||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||
else:
|
||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||
|
||||
# Run encoder.
|
||||
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
||||
inputs_embeds,
|
||||
full_attention_mask,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
kv_caches=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -155,20 +155,26 @@ class BertPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_bert_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertSelfAttention)
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_bert_self_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[BertOutput] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertSelfOutput)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_bert_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BertOutput)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -285,21 +285,26 @@ class BlipPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_blip2_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2Attention)
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
|
||||
method_replacement={
|
||||
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerSelfOutput)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_blip2_QFormer_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerOutput)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -125,25 +125,33 @@ class BloomPolicy(Policy):
|
|||
target_key=BloomModel)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_bloom_flash_attention_forward(),
|
||||
'dropout_add': get_dropout_add_func()
|
||||
})
|
||||
'dropout_add': get_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BloomAttention)
|
||||
|
||||
# enable jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_bloom_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BloomAttention)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_bloom_mlp_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BloomMLP)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_bloom_gelu_forward(),
|
||||
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=BloomGelu)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -15,7 +15,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
|||
GLMBlock,
|
||||
)
|
||||
|
||||
from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
|
||||
from ..modeling.chatglm2 import (
|
||||
get_chatglm_sequence_parallel_forward_fn,
|
||||
get_flash_core_attention_forward,
|
||||
get_jit_fused_glm_block_forward,
|
||||
)
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@ -45,8 +49,8 @@ class ChatGLMPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
||||
policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -55,36 +59,42 @@ class ChatGLMPolicy(Policy):
|
|||
)
|
||||
])
|
||||
|
||||
policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={
|
||||
"self_attention.num_attention_heads_per_partition":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.projection_size":
|
||||
(self.model.config.kv_channels * self.model.config.num_attention_heads) //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
"self_attention.qkv_hidden_size":
|
||||
(self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.num_attention_heads_per_partition":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.hidden_size_per_partition":
|
||||
self.model.config.kv_channels * self.model.config.num_attention_heads //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.core_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
])
|
||||
policy[GLMBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attention.num_attention_heads_per_partition":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.projection_size":
|
||||
(self.model.config.kv_channels * self.model.config.num_attention_heads) //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
"self_attention.qkv_hidden_size":
|
||||
(self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.num_attention_heads_per_partition":
|
||||
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.hidden_size_per_partition":
|
||||
self.model.config.kv_channels * self.model.config.num_attention_heads //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
'seq_parallel': use_sequence_parallel,
|
||||
'seq_parallel_dim': 0
|
||||
}),
|
||||
SubModuleReplacementDescription(suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
'seq_parallel': use_sequence_parallel,
|
||||
'seq_parallel_dim': 0
|
||||
}),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.core_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
])
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
if not self.model.config.rmsnorm:
|
||||
|
@ -124,16 +134,27 @@ class ChatGLMPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[CoreAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_flash_core_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=CoreAttention)
|
||||
|
||||
# use sequence parallel
|
||||
if use_sequence_parallel:
|
||||
self.append_or_create_method_replacement(
|
||||
description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[GLMBlock] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_glm_block_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GLMBlock)
|
||||
|
||||
return policy
|
||||
|
||||
|
@ -178,7 +199,13 @@ class ChatGLMPolicy(Policy):
|
|||
|
||||
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
method_replacement = {
|
||||
'forward':
|
||||
partial(new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=self.shard_config)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
|
||||
|
|
|
@ -118,9 +118,11 @@ class GPT2Policy(Policy):
|
|||
target_key=GPT2Block)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_gpt2_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GPT2Attention)
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
|
||||
|
|
|
@ -105,9 +105,11 @@ class LlamaPolicy(Policy):
|
|||
target_key=LlamaModel)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_llama_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=LlamaAttention)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -199,12 +199,16 @@ class SamPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[SamAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_sam_flash_attention_forward(),
|
||||
})
|
||||
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=SamAttention)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_sam_vision_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=SamVisionAttention)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -90,16 +90,20 @@ class ViTPolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_vit_flash_self_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=ViTSelfAttention)
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_vit_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=ViTOutput)
|
||||
return policy
|
||||
|
||||
def new_model_class(self):
|
||||
|
|
|
@ -12,8 +12,8 @@ from ..registry import ModelAttribute, model_zoo
|
|||
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
|
||||
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue