|
|
|
@ -1,13 +1,15 @@
|
|
|
|
|
import math
|
|
|
|
|
import warnings
|
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
|
from transformers.cache_utils import Cache, DynamicCache
|
|
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
|
|
|
from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv
|
|
|
|
|
from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
|
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
@ -333,121 +335,28 @@ class CommandPipelineForwards:
|
|
|
|
|
return {"hidden_states": hidden_states}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|
|
|
|
from transformers import CohereForCausalLM
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self: CohereForCausalLM,
|
|
|
|
|
input_ids: torch.LongTensor = None,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
labels: Optional[torch.LongTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
|
|
|
r"""
|
|
|
|
|
Args:
|
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
|
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
|
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
|
|
|
|
|
|
|
|
|
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
|
|
|
|
|
|
|
|
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
|
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Generate
|
|
|
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
|
|
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
|
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
|
|
|
```"""
|
|
|
|
|
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
|
|
outputs = self.model(
|
|
|
|
|
input_ids=input_ids,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
position_ids=position_ids,
|
|
|
|
|
past_key_values=past_key_values,
|
|
|
|
|
inputs_embeds=inputs_embeds,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
return_dict=return_dict,
|
|
|
|
|
cache_position=cache_position,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
|
|
|
|
|
logits = self.lm_head(hidden_states)
|
|
|
|
|
logits = logits * self.logit_scale
|
|
|
|
|
logits = logits.float()
|
|
|
|
|
|
|
|
|
|
loss = None
|
|
|
|
|
if labels is not None:
|
|
|
|
|
# Shift so that tokens < n predict n
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
|
shift_logits,
|
|
|
|
|
shift_labels,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
vocab_size=self.lm_head.out_features,
|
|
|
|
|
dtype=self.model.dtype,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
|
return (loss,) + output if loss is not None else output
|
|
|
|
|
|
|
|
|
|
return CausalLMOutputWithPast(
|
|
|
|
|
loss=loss,
|
|
|
|
|
logits=logits,
|
|
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
|
|
attentions=outputs.attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_flash_attention):
|
|
|
|
|
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
|
|
|
|
|
|
|
|
|
|
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
|
|
|
past_key_value: Optional[Cache] = None,
|
|
|
|
|
output_attentions: bool = False,
|
|
|
|
|
use_cache: bool = False,
|
|
|
|
|
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
|
|
|
|
if sp_mode is not None:
|
|
|
|
|
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
|
|
|
|
|
assert (sp_size is not None) and (
|
|
|
|
|
sp_group is not None
|
|
|
|
|
), "Must specify sp_size and sp_group for sequence parallel"
|
|
|
|
|
if "padding_mask" in kwargs:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
|
|
|
# sp: modify sp_len when sequence parallel mode is ring
|
|
|
|
|
if sp_mode in ["split_gather", "ring"]:
|
|
|
|
@ -468,29 +377,46 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_f
|
|
|
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
|
|
|
|
kv_seq_len = key_states.shape[-2]
|
|
|
|
|
if past_key_value is not None:
|
|
|
|
|
if self.layer_idx is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
|
|
|
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
|
|
|
|
"with a layer index."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
|
|
|
|
|
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
|
|
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
|
|
|
|
|
|
|
|
|
if past_key_value is not None:
|
|
|
|
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
|
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
|
|
|
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
|
|
|
|
|
|
|
|
|
# repeat k/v heads if n_kv_heads < n_heads
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
|
|
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
|
|
|
if use_flash_attention:
|
|
|
|
|
|
|
|
|
|
if shard_config.enable_flash_attention:
|
|
|
|
|
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
|
|
|
|
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
|
|
|
else:
|
|
|
|
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
|
|
|
|
|
|
|
|
if attention_mask is not None: # no matter the length, we just slice it
|
|
|
|
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
|
|
|
|
attn_weights = attn_weights + causal_mask
|
|
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
|
|
|
|
f" {attn_weights.size()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
|
|
|
)
|
|
|
|
|
attn_weights = attn_weights + attention_mask
|
|
|
|
|
|
|
|
|
|
# upcast attention to fp32
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
|
|
@ -502,25 +428,28 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group, use_f
|
|
|
|
|
f" {attn_output.size()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
|
|
# sp: all-to-all comminucation when introducing sequence parallel
|
|
|
|
|
if sp_mode == "all_to_all":
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
|
|
|
|
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
|
|
|
|
else:
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
|
|
|
|
|
|
|
|
attn_output = self.o_proj(attn_output)
|
|
|
|
|
|
|
|
|
|
if not output_attentions or use_flash_attention:
|
|
|
|
|
if not output_attentions:
|
|
|
|
|
attn_weights = None
|
|
|
|
|
return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash_attention):
|
|
|
|
|
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self: CohereModel,
|
|
|
|
|
self,
|
|
|
|
|
input_ids: torch.LongTensor = None,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
@ -537,18 +466,14 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
# retrieve input_ids and inputs_embeds
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None:
|
|
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one"
|
|
|
|
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
|
|
|
|
|
|
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
|
|
|
|
if use_cache:
|
|
|
|
|
logger.warning_once(
|
|
|
|
@ -556,7 +481,11 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
)
|
|
|
|
|
use_cache = False
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
|
|
|
|
|
|
past_seen_tokens = 0
|
|
|
|
|
seq_len = inputs_embeds.shape[1]
|
|
|
|
|
if use_cache: # kept for BC (cache positions)
|
|
|
|
|
if not isinstance(past_key_values, StaticCache):
|
|
|
|
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
|
|
|
@ -564,18 +493,18 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
if cache_position is None:
|
|
|
|
|
if isinstance(past_key_values, StaticCache):
|
|
|
|
|
raise ValueError("cache_position is a required argument when using StaticCache.")
|
|
|
|
|
cache_position = torch.arange(
|
|
|
|
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
|
|
|
|
)
|
|
|
|
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
|
|
|
|
|
|
|
|
|
if position_ids is None:
|
|
|
|
|
position_ids = cache_position.unsqueeze(0)
|
|
|
|
|
if use_flash_attention:
|
|
|
|
|
hidden_states = inputs_embeds
|
|
|
|
|
mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
|
|
|
|
|
|
|
|
|
|
# in this case, attention_mask is a dict rather than a tensor
|
|
|
|
|
if shard_config.enable_flash_attention:
|
|
|
|
|
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
|
|
|
|
|
attention_mask = ColoAttention.prepare_attn_kwargs(
|
|
|
|
|
mask_shape,
|
|
|
|
|
hidden_states.dtype,
|
|
|
|
|
hidden_states.device,
|
|
|
|
|
inputs_embeds.dtype,
|
|
|
|
|
inputs_embeds.device,
|
|
|
|
|
q_padding_mask=attention_mask,
|
|
|
|
|
is_causal=True,
|
|
|
|
|
)
|
|
|
|
@ -586,32 +515,26 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
|
|
|
|
elif sp_mode == "all_to_all":
|
|
|
|
|
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
|
|
|
|
|
|
|
|
|
hidden_states = inputs_embeds
|
|
|
|
|
|
|
|
|
|
# decoder layers
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None
|
|
|
|
|
all_self_attns = () if output_attentions else None
|
|
|
|
|
next_decoder_cache = () if use_cache else None
|
|
|
|
|
next_decoder_cache = None
|
|
|
|
|
|
|
|
|
|
for idx, decoder_layer in enumerate(self.layers):
|
|
|
|
|
for decoder_layer in self.layers:
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states += (hidden_states,)
|
|
|
|
|
|
|
|
|
|
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
|
|
|
|
|
|
|
|
|
|
def create_custom_forward(module):
|
|
|
|
|
def custom_forward(*inputs):
|
|
|
|
|
# None for past_key_value
|
|
|
|
|
return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions)
|
|
|
|
|
|
|
|
|
|
return custom_forward
|
|
|
|
|
|
|
|
|
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
|
|
|
|
create_custom_forward(decoder_layer),
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
|
|
layer_outputs = self._gradient_checkpointing_func(
|
|
|
|
|
decoder_layer.__call__,
|
|
|
|
|
hidden_states,
|
|
|
|
|
attention_mask,
|
|
|
|
|
position_ids,
|
|
|
|
|
past_key_values,
|
|
|
|
|
output_attentions,
|
|
|
|
|
use_cache,
|
|
|
|
|
cache_position,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
@ -628,11 +551,7 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
hidden_states = layer_outputs[0]
|
|
|
|
|
|
|
|
|
|
if use_cache:
|
|
|
|
|
next_decoder_cache = (
|
|
|
|
|
next_decoder_cache.to_legacy_cache()
|
|
|
|
|
if isinstance(next_decoder_cache, Cache)
|
|
|
|
|
else next_decoder_cache
|
|
|
|
|
)
|
|
|
|
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
|
|
|
|
|
|
|
|
if output_attentions:
|
|
|
|
|
all_self_attns += (layer_outputs[1],)
|
|
|
|
@ -648,7 +567,11 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states += (hidden_states,)
|
|
|
|
|
|
|
|
|
|
next_cache = next_decoder_cache if use_cache else None
|
|
|
|
|
next_cache = None
|
|
|
|
|
if use_cache:
|
|
|
|
|
next_cache = (
|
|
|
|
|
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
|
|
|
|
)
|
|
|
|
|
if not return_dict:
|
|
|
|
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
|
|
|
|
|
|
|
@ -660,3 +583,104 @@ def get_command_seq_parallel_model_forward(sp_mode, sp_size, sp_group, use_flash
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|
|
|
|
from transformers import CohereForCausalLM
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self: CohereForCausalLM,
|
|
|
|
|
input_ids: torch.LongTensor = None,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
labels: Optional[torch.LongTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
|
|
|
r"""
|
|
|
|
|
Args:
|
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
|
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
|
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, CohereForCausalLM
|
|
|
|
|
|
|
|
|
|
>>> model = CohereForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
|
|
|
|
|
|
|
|
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
|
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Generate
|
|
|
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
|
|
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
|
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
|
|
|
```"""
|
|
|
|
|
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
|
|
outputs = self.model(
|
|
|
|
|
input_ids=input_ids,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
position_ids=position_ids,
|
|
|
|
|
past_key_values=past_key_values,
|
|
|
|
|
inputs_embeds=inputs_embeds,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
return_dict=return_dict,
|
|
|
|
|
cache_position=cache_position,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
|
|
|
|
|
logits = self.lm_head(hidden_states)
|
|
|
|
|
logits = logits * self.logit_scale
|
|
|
|
|
logits = logits.float()
|
|
|
|
|
|
|
|
|
|
loss = None
|
|
|
|
|
if labels is not None:
|
|
|
|
|
# Shift so that tokens < n predict n
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
|
shift_logits,
|
|
|
|
|
shift_labels,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
vocab_size=self.lm_head.out_features,
|
|
|
|
|
dtype=self.model.dtype,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
|
return (loss,) + output if loss is not None else output
|
|
|
|
|
|
|
|
|
|
return CausalLMOutputWithPast(
|
|
|
|
|
loss=loss,
|
|
|
|
|
logits=logits,
|
|
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
|
|
attentions=outputs.attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|