add parallel output for mistral model

pull/5684/head
wangbluo 7 months ago
parent d3f34ee8cc
commit 9efc79ef24

colossalai/shardformer/modeling/mistral.py

@@ -16,7 +16,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention
+from ..layer import ColoAttention, cross_entropy_1d

 logger = logging.get_logger(__name__)
@@ -270,11 +270,22 @@ class MistralForwards:
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                #shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1)
                 # Enable model parallelism
                 shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                    )
+                else:
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels)

             if not return_dict:
                 output = (logits,) + outputs[1:]
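Editorial note, not part of the commit: when parallel_output is enabled, each tensor-parallel rank holds only its shard of the vocabulary dimension of the logits, so the loss has to be computed without gathering the full vocabulary. That is what cross_entropy_1d provides. The snippet below is a rough, forward-only sketch of the idea behind such a vocab-parallel cross entropy; it is not ColossalAI's actual cross_entropy_1d (which also handles autograd, dtype and ignore-index details), and the helper name vocab_parallel_cross_entropy_sketch is made up for illustration.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy_sketch(local_logits, labels, vocab_start, group=None):
    """Forward-only sketch: loss over logits whose vocab dim is sharded across ranks."""
    # numerically stable log-sum-exp over the *global* vocabulary
    local_max = local_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    sum_exp = (local_logits - local_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    log_z = sum_exp.log() + local_max  # (N, 1) log partition function

    # pick out the target logit on whichever rank owns that vocabulary index
    vocab_end = vocab_start + local_logits.size(-1)
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_idx = (labels - vocab_start).clamp(0, local_logits.size(-1) - 1)
    target = local_logits.gather(1, local_idx.unsqueeze(1)).squeeze(1)
    target = torch.where(in_shard, target, torch.zeros_like(target))
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=group)

    # same ignore_index convention as CrossEntropyLoss
    valid = labels != -100
    return (log_z.squeeze(1) - target)[valid].mean()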
@@ -609,3 +620,105 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
         return attn_output, None, past_key_value

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import MistralForCausalLM
+
+    def forward(
+        self: MistralForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
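One practical consequence worth noting (not stated in the diff itself): with parallel_output enabled the lm_head is built with gather_output=False, so the logits returned by this forward stay sharded along the vocabulary dimension and only the loss is globally correct. If full-vocabulary logits are needed, for example for sampling, they have to be gathered explicitly. A minimal sketch, assuming equal shard sizes (which the make_vocab_size_divisible_by padding is meant to guarantee) and a hypothetical helper name:

import torch
import torch.distributed as dist

def gather_vocab_sharded_logits(local_logits: torch.Tensor, group=None) -> torch.Tensor:
    """Collect every rank's vocab shard and concatenate along the last dimension."""
    world_size = dist.get_world_size(group=group)
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits.contiguous(), group=group)
    return torch.cat(shards, dim=-1)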

colossalai/shardformer/policies/mistral.py

@@ -20,6 +20,7 @@ from ..modeling.mistral import (
     MistralForwards,
     get_mistral_flash_attention_forward,
     get_mistral_model_forward_for_flash_attn,
+    get_lm_forward_with_dist_cross_entropy,
 )

 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -275,14 +276,19 @@ class MistralForCausalLMPolicy(MistralPolicy):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=VocabParallelLMHead1D,
-                            kwargs=dict(
-                                gather_output=True,
-                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                            ),
+                            kwargs={
+                                #gather_output=True,
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            },
                         )
                     ]
                 )
             }
+            if self.shard_config.parallel_output:
+                new_item[MistralForCausalLM].method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
         else:
             new_item = {
                 MistralForCausalLM: ModulePolicyDescription(
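A rough end-to-end usage sketch, not part of this commit: it assumes the usual ShardFormer entry point and that ShardConfig exposes the parallel_output, enable_tensor_parallelism and tensor_parallel_process_group fields referenced above; argument names and defaults may differ between versions.

import torch.distributed as dist
from transformers import MistralForCausalLM
from colossalai.shardformer import ShardConfig, ShardFormer

# assumes torch.distributed is already initialized and all ranks form the TP group
tp_group = dist.group.WORLD

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep lm_head output sharded and use the distributed cross entropy
)
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model, _ = ShardFormer(shard_config=shard_config).optimize(model)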
