mirror of https://github.com/hpcaitech/ColossalAI
feat: support qwen2 model
parent
61545fcfee
commit
5c2a47a667
|
@ -0,0 +1,604 @@
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import (
|
||||||
|
Qwen2ForCausalLM,
|
||||||
|
Qwen2ForSequenceClassification,
|
||||||
|
Qwen2Model,
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
from ..layer import cross_entropy_1d
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2PipelineForwards:
|
||||||
|
"""
|
||||||
|
This class serves as a micro library for forward function substitution of Qwen2 models
|
||||||
|
under pipeline setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def qwen2_model_forward(
|
||||||
|
self: Qwen2Model,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None,
|
||||||
|
):
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
|
output_hidden_states = False
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
||||||
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
|
if is_padding_right:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to perform batched generation with padding_side='right'"
|
||||||
|
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
|
||||||
|
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._attn_implementation == "flash_attention_2":
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
|
elif self._attn_implementation == "sdpa" and not output_attentions:
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, None)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
# always return dict for imediate stage
|
||||||
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def qwen2_for_causal_lm_forward(
|
||||||
|
self: Qwen2ForCausalLM,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
|
output_hidden_states = False
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = Qwen2PipelineForwards.qwen2_model_forward(
|
||||||
|
self.model,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
if shard_config.enable_tensor_parallelism:
|
||||||
|
new_vocab_size = logits.shape[-1]
|
||||||
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = outputs.get("hidden_states")
|
||||||
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def qwen2_for_sequence_classification_forward(
|
||||||
|
self: Qwen2ForSequenceClassification,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
shard_config: ShardConfig = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
"""
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
|
output_hidden_states = False
|
||||||
|
|
||||||
|
transformer_outputs = Qwen2PipelineForwards.qwen2_model_forward(
|
||||||
|
self.model,
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
else:
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
logits = self.score(hidden_states)
|
||||||
|
|
||||||
|
if self.config.pad_token_id is None and batch_size != 1:
|
||||||
|
print(self.config.pad_token_id)
|
||||||
|
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||||
|
if self.config.pad_token_id is None:
|
||||||
|
sequence_lengths = -1
|
||||||
|
else:
|
||||||
|
if input_ids is not None:
|
||||||
|
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||||
|
else:
|
||||||
|
sequence_lengths = -1
|
||||||
|
|
||||||
|
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
labels = labels.to(logits.device)
|
||||||
|
if self.config.problem_type is None:
|
||||||
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
if self.num_labels == 1:
|
||||||
|
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||||
|
else:
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(pooled_logits, labels)
|
||||||
|
if not return_dict:
|
||||||
|
output = (pooled_logits,) + transformer_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=pooled_logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
hidden_states = transformer_outputs.get("hidden_states")
|
||||||
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
|
||||||
|
def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self: Qwen2Attention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
assert past_key_value is None, "past_key_value is not supported for Qwen2 models at the moment."
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||||
|
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||||
|
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||||
|
|
||||||
|
flash_attention_mask = None
|
||||||
|
attn_mask_type = AttnMaskType.causal
|
||||||
|
if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||||
|
attn_mask_type = AttnMaskType.paddedcausal
|
||||||
|
|
||||||
|
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
|
||||||
|
attn_output = attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=flash_attention_mask,
|
||||||
|
attn_mask_type=attn_mask_type,
|
||||||
|
origin_attn_mask=attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||||
|
from transformers import Qwen2ForCausalLM
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self: Qwen2ForCausalLM,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
if shard_config.enable_tensor_parallelism:
|
||||||
|
new_vocab_size = logits.shape[-1]
|
||||||
|
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
return forward
|
|
@ -0,0 +1,336 @@
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
||||||
|
|
||||||
|
from ..modeling.qwen2 import (
|
||||||
|
Qwen2PipelineForwards,
|
||||||
|
get_lm_forward_with_dist_cross_entropy,
|
||||||
|
get_qwen2_flash_attention_forward,
|
||||||
|
)
|
||||||
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Policy(Policy):
|
||||||
|
def config_sanity_check(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
# Resize embedding
|
||||||
|
vocab_size = self.model.config.vocab_size
|
||||||
|
world_size = self.shard_config.tensor_parallel_size
|
||||||
|
|
||||||
|
if vocab_size % world_size != 0:
|
||||||
|
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||||
|
self.model.resize_token_embeddings(new_vocab_size)
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2DecoderLayer, Qwen2Model
|
||||||
|
|
||||||
|
policy = {}
|
||||||
|
|
||||||
|
norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm
|
||||||
|
|
||||||
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
|
self.shard_config.enable_sequence_parallelism = False
|
||||||
|
warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||||
|
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||||
|
}
|
||||||
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
||||||
|
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
policy[Qwen2DecoderLayer] = ModulePolicyDescription(
|
||||||
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.gate_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.up_proj",
|
||||||
|
target_module=Linear1D_Col,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp.down_proj",
|
||||||
|
target_module=Linear1D_Row,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=SubModuleReplacementDescription(
|
||||||
|
suffix="embed_tokens",
|
||||||
|
target_module=VocabParallelEmbedding1D,
|
||||||
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key=Qwen2Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimization configuration
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="input_layernorm",
|
||||||
|
target_module=norm_cls,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="post_attention_layernorm",
|
||||||
|
target_module=norm_cls,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
policy=policy,
|
||||||
|
target_key=Qwen2DecoderLayer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=SubModuleReplacementDescription(
|
||||||
|
suffix="norm",
|
||||||
|
target_module=norm_cls,
|
||||||
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key=Qwen2Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# use flash attention
|
||||||
|
if self.shard_config.enable_flash_attention:
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description={
|
||||||
|
"forward": get_qwen2_flash_attention_forward(self.shard_config),
|
||||||
|
},
|
||||||
|
policy=policy,
|
||||||
|
target_key=Qwen2Attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||||
|
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||||
|
to customized forward method, and add this changing to policy."""
|
||||||
|
if self.pipeline_stage_manager is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
if self.model.__class__.__name__ == "Qwen2Model":
|
||||||
|
module = self.model
|
||||||
|
else:
|
||||||
|
module = self.model.model
|
||||||
|
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
layers_per_stage = self.distribute_layers(
|
||||||
|
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
|
)
|
||||||
|
stage_manager.stage_indices = Policy.get_stage_index(
|
||||||
|
layers_per_stage,
|
||||||
|
stage_manager.stage,
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||||
|
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
method_replacement = {
|
||||||
|
"forward": partial(
|
||||||
|
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description=method_replacement, policy=policy, target_key=model_cls
|
||||||
|
)
|
||||||
|
|
||||||
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
assert self.pipeline_stage_manager is not None
|
||||||
|
|
||||||
|
if self.model.__class__.__name__ == "Qwen2Model":
|
||||||
|
module = self.model
|
||||||
|
else:
|
||||||
|
module = self.model.model
|
||||||
|
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
|
held_layers = []
|
||||||
|
if stage_manager.is_interleave:
|
||||||
|
assert stage_manager.num_model_chunks is not None
|
||||||
|
layers_per_stage = self.distribute_layers(
|
||||||
|
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||||
|
)
|
||||||
|
stage_indices = Policy.get_stage_index(
|
||||||
|
layers_per_stage,
|
||||||
|
stage_manager.stage,
|
||||||
|
num_model_chunks=stage_manager.num_model_chunks,
|
||||||
|
num_stages=stage_manager.num_stages,
|
||||||
|
)
|
||||||
|
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
for start_idx, end_idx in stage_indices:
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
|
else:
|
||||||
|
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(module.embed_tokens)
|
||||||
|
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||||
|
held_layers.extend(module.layers[start_idx:end_idx])
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ModelPolicy(Qwen2Policy):
|
||||||
|
def module_policy(self):
|
||||||
|
policy = super().module_policy()
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
# set None as default
|
||||||
|
self.set_pipeline_forward(
|
||||||
|
model_cls=Qwen2Model, new_forward=Qwen2PipelineForwards.qwen2_model_forward, policy=policy
|
||||||
|
)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
"""No shared params in Qwen2 model"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import Qwen2ForCausalLM
|
||||||
|
|
||||||
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
setattr(self.shard_config, "causal_lm", True)
|
||||||
|
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
# add a new item for casual lm
|
||||||
|
new_item = {
|
||||||
|
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
|
||||||
|
],
|
||||||
|
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
# set None as default
|
||||||
|
self.set_pipeline_forward(
|
||||||
|
model_cls=Qwen2ForCausalLM, new_forward=Qwen2PipelineForwards.qwen2_for_causal_lm_forward, policy=policy
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(self.model.lm_head)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
qwen2_model = self.model.model
|
||||||
|
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||||
|
if (
|
||||||
|
id(qwen2_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||||
|
and self.pipeline_stage_manager.num_stages > 1
|
||||||
|
):
|
||||||
|
# tie weights
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
0: qwen2_model.embed_tokens.weight,
|
||||||
|
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import Qwen2ForSequenceClassification
|
||||||
|
|
||||||
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
# add a new item for sequence classification
|
||||||
|
new_item = {
|
||||||
|
Qwen2ForSequenceClassification: ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
policy.update(new_item)
|
||||||
|
# to be confirmed
|
||||||
|
if self.pipeline_stage_manager:
|
||||||
|
# set None as default
|
||||||
|
self.set_pipeline_forward(
|
||||||
|
model_cls=Qwen2ForSequenceClassification,
|
||||||
|
new_forward=Qwen2PipelineForwards.qwen2_for_sequence_classification_forward,
|
||||||
|
policy=policy,
|
||||||
|
)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
"""Get pipeline layers for current stage."""
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(self.model.score)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
"""No shared params in Qwen2 for sequence classification model"""
|
||||||
|
return []
|
Loading…
Reference in New Issue