@@ -14,6 +14,7 @@ from transformers.modeling_attn_mask_utils import (
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
@@ -31,6 +32,8 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d


def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
    def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ class FalconPipelineForwards:
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )
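            # With vocab-parallel tensor parallelism each rank only holds a shard of the vocabulary
            # logits, so the loss is computed with the distributed cross_entropy_1d kernel instead.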
            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
                new_vocab_size = shift_logits.shape[-1]
                shift_logits = shift_logits.view(-1, new_vocab_size)
                shift_labels = shift_labels.view(-1)
                loss = cross_entropy_1d(
                    shift_logits,
                    shift_labels,
                    process_group=shard_config.tensor_parallel_process_group,
                    vocab_size=self.lm_head.out_features,
                    dtype=self.transformer.dtype,
                )
            else:
                loss = loss_fct(
                    shift_logits.view(batch_size * seq_length, vocab_size),
                    shift_labels.view(batch_size * seq_length),
                )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
@@ -747,3 +764,79 @@ class FalconPipelineForwards:
        else:
            hidden_states = outputs.get("hidden_states")
            return {"hidden_states": hidden_states}


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    from transformers import FalconForCausalLM

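    # Builds a drop-in replacement for FalconForCausalLM.forward that always computes the loss
    # with cross_entropy_1d, using the tensor parallel group captured from shard_config.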
    def forward(
        self: FalconForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        past_key_values = None
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
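            # lm_logits hold only this rank's vocabulary shard here, so the loss is reduced across
            # shard_config.tensor_parallel_process_group by cross_entropy_1d rather than CrossEntropyLoss.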
            new_vocab_size = shift_logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = cross_entropy_1d(
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
                dtype=self.transformer.dtype,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    return forward
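

# Usage sketch (not part of this diff): a FalconForCausalLM shardformer policy could swap in the
# forward built above, assuming ColossalAI's Policy.append_or_create_method_replacement API; the
# exact wiring below is illustrative only.
#
#     if self.shard_config.enable_tensor_parallelism and self.shard_config.parallel_output:
#         self.append_or_create_method_replacement(
#             description={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
#             policy=policy,
#             target_key=FalconForCausalLM,
#         )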