mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5684 from wangbluo/parallel_output
[Shardformer] Add Parallel output for shardformer modelspull/5525/head
commit
22297789ab
|
@ -16,7 +16,7 @@ from transformers.utils import logging
|
|||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -270,11 +270,21 @@ class MistralForwards:
|
|||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
@ -609,3 +619,100 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
|
|||
return attn_output, None, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
from transformers import MistralForCausalLM
|
||||
|
||||
def forward(
|
||||
self: MistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -22,6 +22,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -336,8 +338,22 @@ class OPTPipelineForwards:
|
|||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
||||
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
@ -844,3 +860,146 @@ def get_jit_fused_opt_decoder_layer_forward():
|
|||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self: OPTForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
||||
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.shardformer.layer import (
|
|||
|
||||
from ..modeling.mistral import (
|
||||
MistralForwards,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
get_mistral_flash_attention_forward,
|
||||
get_mistral_model_forward_for_flash_attn,
|
||||
)
|
||||
|
@ -275,14 +276,18 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
),
|
||||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
new_item[MistralForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
else:
|
||||
new_item = {
|
||||
MistralForCausalLM: ModulePolicyDescription(
|
||||
|
|
|
@ -21,6 +21,7 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
|
|||
from ..modeling.opt import (
|
||||
OPTPipelineForwards,
|
||||
get_jit_fused_opt_decoder_layer_forward,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
get_opt_decoder_forward_for_flash_attention,
|
||||
get_opt_flash_attention_forward,
|
||||
)
|
||||
|
@ -269,12 +270,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
if self.shard_config.parallel_output:
|
||||
method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=OPTForCausalLM
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
|
Loading…
Reference in New Issue