|
|
|
@ -16,7 +16,7 @@ from transformers.utils import logging
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.shardformer.shard import ShardConfig
|
|
|
|
|
|
|
|
|
|
from ..layer import ColoAttention
|
|
|
|
|
from ..layer import ColoAttention, cross_entropy_1d
|
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
@ -270,11 +270,22 @@ class MistralForwards:
|
|
|
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
|
# Flatten the tokens
|
|
|
|
|
loss_fct = CrossEntropyLoss()
|
|
|
|
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
|
|
|
#shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
|
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
|
shift_logits,
|
|
|
|
|
shift_labels,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
vocab_size=self.lm_head.out_features,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
|
|
|
loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
@ -609,3 +620,105 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|
|
|
|
return attn_output, None, past_key_value
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|
|
|
|
from transformers import MistralForCausalLM
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self: MistralForCausalLM,
|
|
|
|
|
input_ids: torch.LongTensor = None,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
|
|
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
labels: Optional[torch.LongTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
return_dict: Optional[bool] = None,
|
|
|
|
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
|
|
|
r"""
|
|
|
|
|
Args:
|
|
|
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
|
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
|
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
|
|
|
|
|
|
|
|
|
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
|
|
|
|
|
|
|
|
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
|
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
>>> # Generate
|
|
|
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
|
|
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
|
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
|
|
|
```"""
|
|
|
|
|
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
|
|
outputs = self.model(
|
|
|
|
|
input_ids=input_ids,
|
|
|
|
|
attention_mask=attention_mask,
|
|
|
|
|
position_ids=position_ids,
|
|
|
|
|
past_key_values=past_key_values,
|
|
|
|
|
inputs_embeds=inputs_embeds,
|
|
|
|
|
use_cache=use_cache,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
return_dict=return_dict,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
if self.config.pretraining_tp > 1:
|
|
|
|
|
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
|
|
|
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
|
|
|
logits = torch.cat(logits, dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
logits = self.lm_head(hidden_states)
|
|
|
|
|
logits = logits.float()
|
|
|
|
|
|
|
|
|
|
loss = None
|
|
|
|
|
if labels is not None:
|
|
|
|
|
# Shift so that tokens < n predict n
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
|
|
shift_labels = labels[..., 1:].contiguous()
|
|
|
|
|
shift_labels = shift_labels.view(-1)
|
|
|
|
|
# Enable model parallelism
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
|
new_vocab_size = logits.shape[-1]
|
|
|
|
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
|
|
|
|
loss = cross_entropy_1d(
|
|
|
|
|
shift_logits,
|
|
|
|
|
shift_labels,
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group,
|
|
|
|
|
vocab_size=self.lm_head.out_features,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
output = (logits,) + outputs[1:]
|
|
|
|
|
return (loss,) + output if loss is not None else output
|
|
|
|
|
|
|
|
|
|
return CausalLMOutputWithPast(
|
|
|
|
|
loss=loss,
|
|
|
|
|
logits=logits,
|
|
|
|
|
past_key_values=outputs.past_key_values,
|
|
|
|
|
hidden_states=outputs.hidden_states,
|
|
|
|
|
attentions=outputs.attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return forward
|
|
|
|
|