@@ -9,7 +9,7 @@ from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_sp_output,
@@ -25,42 +25,7 @@ def get_flash_core_attention_forward():
    def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            attention_mask_type = AttnMaskType.CAUSAL
            attn_bias = torch.zeros(
                query_layer.shape[0],
                1,
                query_layer.shape[2],
                key_layer.shape[2],
                dtype=query_layer.dtype,
                device=query_layer.device,
            )
            temp_mask = (
                torch.ones(
                    query_layer.shape[2],
                    key_layer.shape[2],
                    dtype=torch.bool,
                    device=query_layer.device,
                )
                .tril(diagonal=0)
                .expand(query_layer.shape[0], 1, -1, -1)
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
        else:
            attention_mask_type = AttnMaskType.CUSTOM
            if attention_mask is not None:
                attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
                attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
        dropout_p = self.attention_dropout.p if self.training else 0.0
        context_layer = ColoAttention.attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask=attn_bias,
            attention_mask_type=attention_mask_type,
            dropout_p=dropout_p,
            scale=1.0 / self.norm_factor,
        )
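        # `attention_mask` here is expected to be the kwargs dict built by
        # ColoAttention.prepare_attn_kwargs (mask tensor, mask type, etc.; see the mask
        # preparation further down), so it is unpacked straight into the fused call.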
        context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask)
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
@@ -180,6 +145,17 @@ class ChatGLMPipelineForwards:
                    ],
                    dim=-1,
                )
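        # With flash attention enabled, the attention kwargs are prepared up front from the
        # padding mask; otherwise fall back to ChatGLM's original dense-mask construction.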
        if shard_config.enable_flash_attention:
            mask_shape = (batch_size, 1, seq_length, seq_length)
            full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                hidden_states.dtype,
                hidden_states.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            if full_attention_mask is None:
                if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -237,7 +213,7 @@ class ChatGLMPipelineForwards:
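                # `full_attention_mask` is either the flash-attention kwargs dict prepared above
                # or the dense mask from get_masks on the non-flash path.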
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    full_attention_mask,
                    rotary_pos_emb,
                    past_key_values[idx],
                    use_cache,
@@ -402,7 +378,16 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                    ],
                    dim=-1,
                )
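        # The sequence-parallel forward uses the same preparation: a (batch, 1, seq, seq) mask
        # shape plus the padding mask go through ColoAttention.prepare_attn_kwargs when flash
        # attention is on.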
        if shard_config.enable_flash_attention:
            mask_shape = (batch_size, 1, seq_length, seq_length)
            full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
                mask_shape,
                hidden_states.dtype,
                hidden_states.device,
                q_padding_mask=attention_mask,
                is_causal=True,
            )
        else:
            if full_attention_mask is None:
                if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -652,3 +637,79 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
        return output, kv_cache

    return forward


def get_flash_attention_forward_for_chat_glm_model():
    from .chatglm2_6b.modeling_chatglm import ChatGLMModel
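
    # Replacement ChatGLMModel.forward used when flash attention is enabled: the causal/padding
    # information is packed into the kwargs dict returned by ColoAttention.prepare_attn_kwargs
    # instead of ChatGLM's dense full_attention_mask, and that dict is handed to the encoder.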
    def forward(
        self: ChatGLMModel,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(
                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
                )
            if attention_mask is not None:
                attention_mask = torch.cat(
                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
                )
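
        # Build the flash-attention kwargs: a (batch, 1, seq_length, seq_length) mask derived from
        # the (possibly prefix-extended) padding mask, marked causal for autoregressive decoding.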
        mask_shape = (batch_size, 1, seq_length, seq_length)
        full_attention_mask: dict = ColoAttention.prepare_attn_kwargs(
            mask_shape,
            inputs_embeds.dtype,
            inputs_embeds.device,
            q_padding_mask=attention_mask,
            is_causal=True,
        )

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward