mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5684 from wangbluo/parallel_output
[Shardformer] Add Parallel output for shardformer models
pull/5525/head
commit 22297789ab
|
@@ -16,7 +16,7 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention
+from ..layer import ColoAttention, cross_entropy_1d

 logger = logging.get_logger(__name__)

@@ -270,11 +270,21 @@ class MistralForwards:
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1)
                 # Enable model parallelism
                 shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                    )
+                else:
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits, shift_labels)

             if not return_dict:
                 output = (logits,) + outputs[1:]
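For context: `cross_entropy_1d` computes the loss directly on vocab-sharded logits instead of gathering the full vocabulary first. Below is a minimal sketch of the general Megatron-style recipe (all-reduce of the per-token max, of the softmax denominator, and of the target logit) written in plain `torch.distributed`. It only illustrates the math; it is not ColossalAI's actual implementation, and the helper name, its arguments, and the contiguous shard layout are assumptions made for the example.

```python
import torch
import torch.distributed as dist


def naive_vocab_parallel_cross_entropy(local_logits, targets, vocab_start, process_group, ignore_index=-100):
    """local_logits: this rank's [N, V_local] slice of the full [N, V] logits.
    targets: [N] global vocab indices, identical on every rank."""
    # 1) Numerically stable softmax denominator across all vocab shards.
    local_max = local_logits.max(dim=-1).values                           # [N]
    global_max = local_max.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=process_group)
    sum_exp = torch.exp(local_logits - global_max[:, None]).sum(dim=-1)   # [N]
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # 2) The target logit lives on exactly one rank; other ranks contribute 0.
    vocab_end = vocab_start + local_logits.shape[-1]
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    local_idx = (targets - vocab_start).clamp(0, local_logits.shape[-1] - 1)
    target_logit = torch.where(
        in_shard,
        local_logits.gather(1, local_idx[:, None]).squeeze(1),
        torch.zeros_like(global_max),
    )
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # 3) loss_i = logsumexp_i - x_{i, y_i}, averaged over non-ignored tokens.
    per_token = torch.log(sum_exp) + global_max - target_logit
    valid = targets != ignore_index
    return per_token[valid].mean()
```

The real kernel additionally has to define the backward pass and deal with vocab padding and dtype details, which this sketch ignores.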
@@ -609,3 +619,100 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig):
         return attn_output, None, past_key_value

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import MistralForCausalLM
+
+    def forward(
+        self: MistralForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
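Because the policy change further down sets `gather_output=not self.shard_config.parallel_output` on the `VocabParallelLMHead1D` replacement, each rank's `logits` cover only its local slice of the vocabulary, which is why the forward above reads the width from `logits.shape[-1]` rather than `self.config.vocab_size`. A single-process toy illustration of those shapes (the sizes and the manual weight chunking are made up for the example; no ColossalAI code is involved):

```python
import torch

# Hypothetical sizes: full vocab 32000, tensor-parallel degree 4.
batch, seq, vocab, tp = 2, 8, 32000, 4

hidden = torch.randn(batch, seq, 128)
full_weight = torch.randn(vocab, 128)

# Each rank of a column-parallel LM head owns a contiguous slice of the weight rows.
local_weight = full_weight.chunk(tp, dim=0)[0]   # rank 0's shard: [8000, 128]
local_logits = hidden @ local_weight.t()         # [2, 8, 8000], not [2, 8, 32000]

new_vocab_size = local_logits.shape[-1]          # 8000 == vocab // tp
assert new_vocab_size == vocab // tp
# config.vocab_size still reports 32000, so flattening with it would fail:
# local_logits.view(-1, vocab)  ->  RuntimeError (shape mismatch)
```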
@@ -22,6 +22,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig

+from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)


@@ -336,8 +338,22 @@ class OPTPipelineForwards:
                 shift_logits = logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                    )
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
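The `else` branch above keeps the original dense `CrossEntropyLoss` path; the parallel branch is intended to produce the same value from sharded logits. A quick single-process sanity check of the underlying identity (cross entropy = log-sum-exp minus the target logit, with the log-sum-exp combinable across vocab chunks), standing in for actual sharding without touching `cross_entropy_1d`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, vocab, n_chunks = 16, 1000, 4
logits = torch.randn(n_tokens, vocab)
targets = torch.randint(0, vocab, (n_tokens,))

# Dense reference, as in the non-parallel branch.
dense = F.cross_entropy(logits, targets)

# Chunked computation: each "shard" contributes a partial log-sum-exp.
chunks = logits.chunk(n_chunks, dim=-1)
partial_lse = torch.stack([c.logsumexp(dim=-1) for c in chunks], dim=0)  # [n_chunks, n_tokens]
lse = partial_lse.logsumexp(dim=0)                                       # combine shards
target_logit = logits.gather(1, targets[:, None]).squeeze(1)
chunked = (lse - target_logit).mean()

assert torch.allclose(dense, chunked, atol=1e-5)
```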
@@ -844,3 +860,146 @@ def get_jit_fused_opt_decoder_layer_forward():
         return outputs

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
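Note that with `parallel_output=True` the `logits` returned in `CausalLMOutputWithPast` above stay sharded along the vocab dimension; that is exactly what `cross_entropy_1d` wants during training, but any consumer that needs full-vocab logits (for example sampling or greedy decoding) would have to gather them first. A hedged sketch of such a gather, assuming equal shard sizes per rank (the helper below is illustrative and not part of this PR):

```python
import torch
import torch.distributed as dist


def gather_vocab_sharded_logits(local_logits: torch.Tensor, process_group=None) -> torch.Tensor:
    """Concatenate per-rank [B, S, V/tp] logits into full [B, S, V] logits on every rank."""
    world_size = dist.get_world_size(group=process_group)
    if world_size == 1:
        return local_logits
    # all_gather requires identical shapes on every rank, which is what the
    # make_vocab_size_divisible_by padding in the policies helps guarantee.
    buckets = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(buckets, local_logits.contiguous(), group=process_group)
    return torch.cat(buckets, dim=-1)
```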
@@ -18,6 +18,7 @@ from colossalai.shardformer.layer import (

 from ..modeling.mistral import (
     MistralForwards,
+    get_lm_forward_with_dist_cross_entropy,
     get_mistral_flash_attention_forward,
     get_mistral_model_forward_for_flash_attn,
 )

@@ -275,14 +276,18 @@ class MistralForCausalLMPolicy(MistralPolicy):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=VocabParallelLMHead1D,
-                            kwargs=dict(
-                                gather_output=True,
-                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                            ),
+                            kwargs={
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            },
                         )
                     ]
                 )
             }
+            if self.shard_config.parallel_output:
+                new_item[MistralForCausalLM].method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
         else:
             new_item = {
                 MistralForCausalLM: ModulePolicyDescription(
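Both policies install the new forward via a method replacement keyed by `"forward"`, passing a closure that has already captured `shard_config`. A rough stand-alone illustration of that binding mechanism in plain Python (not ColossalAI's actual replacement machinery; `DummyModel` and `get_forward_with_config` are invented for the example):

```python
from types import MethodType


def get_forward_with_config(config):
    # The returned function closes over `config`, mirroring how
    # get_lm_forward_with_dist_cross_entropy closes over shard_config.
    def forward(self, x):
        return f"forward({x}) with parallel_output={config['parallel_output']}"

    return forward


class DummyModel:
    def forward(self, x):
        return f"original forward({x})"


model = DummyModel()
model.forward = MethodType(get_forward_with_config({"parallel_output": True}), model)
print(model.forward("tokens"))  # uses the replacement, with the config baked in
```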
@@ -21,6 +21,7 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.opt import (
     OPTPipelineForwards,
     get_jit_fused_opt_decoder_layer_forward,
+    get_lm_forward_with_dist_cross_entropy,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
 )

@@ -269,12 +270,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
                     suffix="lm_head",
                     target_module=VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=OPTForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
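From the user's side the whole feature hangs off `ShardConfig`: the policies read `parallel_output` to choose both the `gather_output` kwarg of `VocabParallelLMHead1D` and whether to swap in `get_lm_forward_with_dist_cross_entropy`. A minimal configuration sketch, assuming a tensor-parallel process group already exists and acknowledging that constructor arguments beyond those visible in this diff may vary between ColossalAI versions:

```python
import torch.distributed as dist

from colossalai.shardformer.shard import ShardConfig

# Assumes torch.distributed is already initialized; in practice `tp_group`
# should be the tensor-parallel sub-group for this rank rather than WORLD.
tp_group = dist.group.WORLD

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep logits vocab-sharded and compute the loss with cross_entropy_1d
)
```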