|
|
@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu |
|
|
|
from transformers.utils import logging |
|
|
|
from transformers.utils import logging |
|
|
|
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager |
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager |
|
|
|
|
|
|
|
from colossalai.shardformer import ShardConfig |
|
|
|
|
|
|
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward |
|
|
|
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig |
|
|
|
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig |
|
|
|
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( |
|
|
|
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( |
|
|
|
ChatGLMForConditionalGeneration, |
|
|
|
ChatGLMForConditionalGeneration, |
|
|
@ -146,6 +148,7 @@ class ChatGLMPipelineForwards: |
|
|
|
stage_manager: Optional[PipelineStageManager] = None, |
|
|
|
stage_manager: Optional[PipelineStageManager] = None, |
|
|
|
hidden_states: Optional[torch.FloatTensor] = None, |
|
|
|
hidden_states: Optional[torch.FloatTensor] = None, |
|
|
|
stage_index: Optional[List[int]] = None, |
|
|
|
stage_index: Optional[List[int]] = None, |
|
|
|
|
|
|
|
shard_config: ShardConfig = None, |
|
|
|
): |
|
|
|
): |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
output_hidden_states = (output_hidden_states |
|
|
|
output_hidden_states = (output_hidden_states |
|
|
@ -198,6 +201,11 @@ class ChatGLMPipelineForwards: |
|
|
|
all_self_attentions = None |
|
|
|
all_self_attentions = None |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
|
start_idx, end_idx = stage_index[0], stage_index[1] |
|
|
|
start_idx, end_idx = stage_index[0], stage_index[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if shard_config.enable_sequence_parallelism: |
|
|
|
|
|
|
|
hidden_states = split_forward_gather_backward(hidden_states, |
|
|
|
|
|
|
|
dim=0, |
|
|
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group) |
|
|
|
for idx in range(start_idx, end_idx): |
|
|
|
for idx in range(start_idx, end_idx): |
|
|
|
layer = self.encoder._get_layer(idx) |
|
|
|
layer = self.encoder._get_layer(idx) |
|
|
|
if output_hidden_states: |
|
|
|
if output_hidden_states: |
|
|
@ -214,6 +222,11 @@ class ChatGLMPipelineForwards: |
|
|
|
hidden_states, kv_cache = layer_ret |
|
|
|
hidden_states, kv_cache = layer_ret |
|
|
|
if use_cache: |
|
|
|
if use_cache: |
|
|
|
presents = presents + (kv_cache,) |
|
|
|
presents = presents + (kv_cache,) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if shard_config.enable_sequence_parallelism: |
|
|
|
|
|
|
|
hidden_states = gather_forward_split_backward(hidden_states, |
|
|
|
|
|
|
|
dim=0, |
|
|
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group) |
|
|
|
if output_hidden_states: |
|
|
|
if output_hidden_states: |
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
if stage_manager.is_last_stage(): |
|
|
|
if stage_manager.is_last_stage(): |
|
|
@ -233,23 +246,22 @@ class ChatGLMPipelineForwards: |
|
|
|
return {'hidden_states': hidden_states} |
|
|
|
return {'hidden_states': hidden_states} |
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@staticmethod |
|
|
|
def chatglm_for_conditional_generation_forward( |
|
|
|
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, |
|
|
|
self: ChatGLMForConditionalGeneration, |
|
|
|
input_ids: Optional[torch.Tensor] = None, |
|
|
|
input_ids: Optional[torch.Tensor] = None, |
|
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
|
past_key_values: Optional[Tuple[torch.FloatTensor]] = None, |
|
|
|
past_key_values: Optional[Tuple[torch.FloatTensor]] = None, |
|
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
|
labels: Optional[torch.Tensor] = None, |
|
|
|
labels: Optional[torch.Tensor] = None, |
|
|
|
use_cache: Optional[bool] = None, |
|
|
|
use_cache: Optional[bool] = None, |
|
|
|
output_attentions: Optional[bool] = None, |
|
|
|
output_attentions: Optional[bool] = None, |
|
|
|
output_hidden_states: Optional[bool] = None, |
|
|
|
output_hidden_states: Optional[bool] = None, |
|
|
|
return_dict: Optional[bool] = None, |
|
|
|
return_dict: Optional[bool] = None, |
|
|
|
return_last_logit: Optional[bool] = False, |
|
|
|
return_last_logit: Optional[bool] = False, |
|
|
|
stage_manager: Optional[PipelineStageManager] = None, |
|
|
|
stage_manager: Optional[PipelineStageManager] = None, |
|
|
|
hidden_states: Optional[torch.FloatTensor] = None, |
|
|
|
hidden_states: Optional[torch.FloatTensor] = None, |
|
|
|
stage_index: Optional[List[int]] = None, |
|
|
|
stage_index: Optional[List[int]] = None, |
|
|
|
shard_config: ShardConfig = None): |
|
|
|
): |
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) |
|
|
|
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) |
|
|
@ -266,6 +278,7 @@ class ChatGLMPipelineForwards: |
|
|
|
stage_manager=stage_manager, |
|
|
|
stage_manager=stage_manager, |
|
|
|
hidden_states=hidden_states, |
|
|
|
hidden_states=hidden_states, |
|
|
|
stage_index=stage_index, |
|
|
|
stage_index=stage_index, |
|
|
|
|
|
|
|
shard_config=shard_config, |
|
|
|
) |
|
|
|
) |
|
|
|
if stage_manager.is_last_stage(): |
|
|
|
if stage_manager.is_last_stage(): |
|
|
|
hidden_states = transformer_outputs[0] |
|
|
|
hidden_states = transformer_outputs[0] |
|
|
@ -296,3 +309,91 @@ class ChatGLMPipelineForwards: |
|
|
|
) |
|
|
|
) |
|
|
|
else: |
|
|
|
else: |
|
|
|
return transformer_outputs |
|
|
|
return transformer_outputs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
|
|
|
|
self, |
|
|
|
|
|
|
|
input_ids, |
|
|
|
|
|
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
|
|
|
|
|
attention_mask: Optional[torch.BoolTensor] = None, |
|
|
|
|
|
|
|
full_attention_mask: Optional[torch.BoolTensor] = None, |
|
|
|
|
|
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, |
|
|
|
|
|
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
|
|
|
|
|
use_cache: Optional[bool] = None, |
|
|
|
|
|
|
|
output_hidden_states: Optional[bool] = None, |
|
|
|
|
|
|
|
return_dict: Optional[bool] = None, |
|
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
output_hidden_states = (output_hidden_states |
|
|
|
|
|
|
|
if output_hidden_states is not None else self.config.output_hidden_states) |
|
|
|
|
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
|
|
|
|
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_size, seq_length = input_ids.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None: |
|
|
|
|
|
|
|
inputs_embeds = self.embedding(input_ids) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.pre_seq_len is not None: |
|
|
|
|
|
|
|
if past_key_values is None: |
|
|
|
|
|
|
|
past_key_values = self.get_prompt( |
|
|
|
|
|
|
|
batch_size=batch_size, |
|
|
|
|
|
|
|
device=input_ids.device, |
|
|
|
|
|
|
|
dtype=inputs_embeds.dtype, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
|
|
attention_mask = torch.cat( |
|
|
|
|
|
|
|
[ |
|
|
|
|
|
|
|
attention_mask.new_ones((batch_size, self.pre_seq_len)), |
|
|
|
|
|
|
|
attention_mask, |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
dim=-1, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if full_attention_mask is None: |
|
|
|
|
|
|
|
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): |
|
|
|
|
|
|
|
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Rotary positional embeddings |
|
|
|
|
|
|
|
rotary_pos_emb = self.rotary_pos_emb(self.seq_length) |
|
|
|
|
|
|
|
if position_ids is not None: |
|
|
|
|
|
|
|
rotary_pos_emb = rotary_pos_emb[position_ids] |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
rotary_pos_emb = rotary_pos_emb[None, :seq_length] |
|
|
|
|
|
|
|
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run encoder. |
|
|
|
|
|
|
|
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] |
|
|
|
|
|
|
|
inputs_embeds = split_forward_gather_backward(inputs_embeds, |
|
|
|
|
|
|
|
dim=0, |
|
|
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group) |
|
|
|
|
|
|
|
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( |
|
|
|
|
|
|
|
inputs_embeds, |
|
|
|
|
|
|
|
full_attention_mask, |
|
|
|
|
|
|
|
rotary_pos_emb=rotary_pos_emb, |
|
|
|
|
|
|
|
kv_caches=past_key_values, |
|
|
|
|
|
|
|
use_cache=use_cache, |
|
|
|
|
|
|
|
output_hidden_states=output_hidden_states, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_states = gather_forward_split_backward(hidden_states, |
|
|
|
|
|
|
|
dim=0, |
|
|
|
|
|
|
|
process_group=shard_config.tensor_parallel_process_group) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not return_dict: |
|
|
|
|
|
|
|
return tuple(v for v in [ |
|
|
|
|
|
|
|
hidden_states, |
|
|
|
|
|
|
|
presents, |
|
|
|
|
|
|
|
all_hidden_states, |
|
|
|
|
|
|
|
all_self_attentions, |
|
|
|
|
|
|
|
] if v is not None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return BaseModelOutputWithPast( |
|
|
|
|
|
|
|
last_hidden_state=hidden_states, |
|
|
|
|
|
|
|
past_key_values=presents, |
|
|
|
|
|
|
|
hidden_states=all_hidden_states, |
|
|
|
|
|
|
|
attentions=all_self_attentions, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return forward |
|
|
|