mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] Support the Qwen2 model (#5699)
* feat: support qwen2 model * fix: modify model config and add Qwen2RMSNorm * fix qwen2 model conflicts * test: add qwen2 shard test * to: add qwen2 auto policy * support qwen model * fix the conflicts * add try catch * add transformers version for qwen2 * add the ColoAttention for the qwen2 model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add the unit test version check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the test input bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the version check * fix the version check --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5703/head
parent
d4c5ef441e
commit
a3cc68ca93
|
@ -0,0 +1,758 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2Model,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
except ImportError:
|
||||
Qwen2Model = "Qwen2Model"
|
||||
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
|
||||
Qwen2ForCausalLM = "Qwen2ForCausalLM"
|
||||
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
|
||||
|
||||
class Qwen2PipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Qwen2 models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def qwen2_model_forward(
|
||||
self: Qwen2Model,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
# assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
# always return dict for imediate stage
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def qwen2_for_causal_lm_forward(
|
||||
self: Qwen2ForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
||||
|
||||
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = Qwen2PipelineForwards.qwen2_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def qwen2_for_sequence_classification_forward(
|
||||
self: Qwen2ForSequenceClassification,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
transformer_outputs = Qwen2PipelineForwards.qwen2_model_forward(
|
||||
self.model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
print(self.config.pad_token_id)
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = transformer_outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
|
||||
def forward(
|
||||
self: Qwen2Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
and cache_has_contents
|
||||
):
|
||||
slicing_tokens = 1 - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[self.layer_idx][0]
|
||||
past_value = past_key_value[self.layer_idx][1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, slicing_tokens:]
|
||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
logger = logging.get_logger(__name__)
|
||||
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
q_padding_mask=attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self: Qwen2ForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
||||
|
||||
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
if shard_config.enable_tensor_parallelism:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
else:
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
return forward
|
|
@ -182,6 +182,16 @@ _POLICY_LIST = {
|
|||
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
|
||||
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
|
||||
),
|
||||
# Qwen2
|
||||
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
|
||||
file_name="qwen2", class_name="Qwen2ModelPolicy"
|
||||
),
|
||||
"transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM": PolicyLocation(
|
||||
file_name="qwen2", class_name="Qwen2ForCausalLMPolicy"
|
||||
),
|
||||
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
|
||||
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,369 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
)
|
||||
|
||||
from ..modeling.qwen2 import (
|
||||
Qwen2PipelineForwards,
|
||||
get_lm_forward_with_dist_cross_entropy,
|
||||
get_qwen2_flash_attention_forward,
|
||||
get_qwen2_model_forward_for_flash_attn,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
|
||||
|
||||
|
||||
class Qwen2Policy(Policy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
import transformers
|
||||
from packaging.version import Version
|
||||
|
||||
assert Version(transformers.__version__) >= Version(
|
||||
"4.39.1"
|
||||
), "The Qwen2 model should run on a transformers version of 4.39.1."
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
try:
|
||||
from transformers.models.qwen2.modeling_qwen2 import (
|
||||
Qwen2Attention,
|
||||
Qwen2DecoderLayer,
|
||||
Qwen2FlashAttention2,
|
||||
Qwen2Model,
|
||||
Qwen2SdpaAttention,
|
||||
)
|
||||
except ImportError:
|
||||
Qwen2Attention = "Qwen2Attention"
|
||||
Qwen2FlashAttention2 = "Qwen2FlashAttention2"
|
||||
Qwen2SdpaAttention = "Qwen2SdpaAttention"
|
||||
Qwen2DecoderLayer = "Qwen2DecoderLayer"
|
||||
Qwen2Model = "Qwen2Model"
|
||||
|
||||
ATTN_IMPLEMENTATION = {
|
||||
"eager": Qwen2Attention,
|
||||
"flash_attention_2": Qwen2FlashAttention2,
|
||||
"sdpa": Qwen2SdpaAttention,
|
||||
}
|
||||
|
||||
policy = {}
|
||||
|
||||
attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
policy[Qwen2DecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Qwen2Model,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Qwen2DecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=Qwen2Model,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_qwen2_flash_attention_forward(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=attn_cls,
|
||||
)
|
||||
if self.pipeline_stage_manager is None:
|
||||
# replace qwen2 model forward method
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_qwen2_model_forward_for_flash_attn(self.shard_config),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Qwen2Model,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager is None:
|
||||
return
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "Qwen2Model":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "Qwen2Model":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class Qwen2ModelPolicy(Qwen2Policy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=Qwen2Model, new_forward=Qwen2PipelineForwards.qwen2_model_forward, policy=policy
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in Qwen2 model"""
|
||||
return []
|
||||
|
||||
|
||||
class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||
def module_policy(self):
|
||||
from transformers import Qwen2ForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
setattr(self.shard_config, "causal_lm", True)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=Qwen2ForCausalLM, new_forward=Qwen2PipelineForwards.qwen2_for_causal_lm_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
qwen2_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(qwen2_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: qwen2_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
||||
def module_policy(self):
|
||||
from transformers import Qwen2ForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
Qwen2ForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
# to be confirmed
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=Qwen2ForSequenceClassification,
|
||||
new_forward=Qwen2PipelineForwards.qwen2_for_sequence_classification_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in Qwen2 for sequence classification model"""
|
||||
return []
|
|
@ -17,3 +17,8 @@ try:
|
|||
from .mistral import *
|
||||
except ImportError:
|
||||
print("This version of transformers doesn't support mistral.")
|
||||
|
||||
try:
|
||||
from .qwen2 import *
|
||||
except ImportError:
|
||||
print("This version of transformers doesn't support qwen2.")
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
import torch
|
||||
import transformers
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
try:
|
||||
from transformers import Qwen2Config
|
||||
|
||||
HAS_QWEN2 = True
|
||||
except ImportError:
|
||||
HAS_QWEN2 = False
|
||||
|
||||
if HAS_QWEN2:
|
||||
# ===============================
|
||||
# Register Qwen2
|
||||
# ===============================
|
||||
|
||||
def data_gen():
|
||||
# the input ids are corresponding to the sentence
|
||||
# 'Hello, my dog is cute'
|
||||
#
|
||||
# the code is give below:
|
||||
# -----------------------------------
|
||||
# from transformers import Qwen2TokenizerFast
|
||||
# tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen1.5-7B-Chat")
|
||||
# input = 'Hello, my dog is cute'
|
||||
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
|
||||
# -----------------------------------
|
||||
|
||||
input_ids = torch.Tensor(
|
||||
[[9707, 11, 847, 5562, 374, 13, 123, 18838], [9707, 11, 847, 5562, 374, 17, 89, 18838]]
|
||||
).long()
|
||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
# label is needed for casual lm
|
||||
def data_gen_for_casual_lm():
|
||||
data = data_gen()
|
||||
labels = data["input_ids"].clone()
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
# transform the output to a dict
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# function to get the loss
|
||||
loss_fn = lambda output: output["last_hidden_state"].mean()
|
||||
loss_fn_for_casual_lm = lambda output: output["loss"]
|
||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||
|
||||
config = Qwen2Config(
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
max_window_layers=4,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=4,
|
||||
num_key_value_heads=16,
|
||||
)
|
||||
|
||||
config.pad_token_id = 0
|
||||
|
||||
# register the following models
|
||||
# transformers.Qwen2Model,
|
||||
# transformers.Qwen2ForCausalLM,
|
||||
# transformers.Qwen2ForSequenceClassification,
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2",
|
||||
model_fn=lambda: transformers.Qwen2Model(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2_for_casual_lm",
|
||||
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_casual_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_casual_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_qwen2_for_sequence_classification",
|
||||
model_fn=lambda: transformers.Qwen2ForSequenceClassification(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_seq_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
|
@ -0,0 +1,235 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config
|
||||
)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
qwen2_model = unwrap_model(org_model, "Qwen2Model", "model")
|
||||
shard_qwen2_model = unwrap_model(sharded_model, "Qwen2Model", "model")
|
||||
|
||||
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
|
||||
col_layer_for_check = ["layers[0].self_attn.o_proj"]
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(
|
||||
qwen2_model, shard_qwen2_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
|
||||
)
|
||||
col_layer_grads = get_grad_tensors_for_check(
|
||||
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == "Qwen2Model":
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(
|
||||
qwen2_model, shard_qwen2_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||
)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 2,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"pp_style": "interleaved",
|
||||
"num_model_chunks": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_qwen2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_qwen2_test()
|
||||
|
||||
|
||||
def check_qwen2_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_qwen2_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_qwen2():
|
||||
spawn(check_qwen2, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_qwen2_3d():
|
||||
spawn(check_qwen2_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_qwen2()
|
||||
test_qwen2_3d()
|
Loading…
Reference in New Issue