[inference] decouple pipeline logic for chatglm (#5098)

* [inference] decouple pipeline logic for chatglm

* [inference] fix chatglm modeling
Hongxin Liu 2023-11-22 18:26:39 +08:00 committed by GitHub
parent cb450c2861
commit f196f40a8f
2 changed files with 24 additions and 6 deletions


@@ -156,7 +156,8 @@ class InferenceEngine:
             input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0]
         )
         # bind the infer state to the model (not lm model)
-        self.model.model.infer_state = batch_infer_state
+        model_to_bind = self.model.model if hasattr(self.model, "model") else self.model.transformer
+        model_to_bind.infer_state = batch_infer_state
         if generation_config is not None:
             generation_config.max_new_tokens = self.max_output_len
         else:
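Context for the engine hunk above: HF-style decoder wrappers such as LlamaForCausalLM expose their backbone as `.model`, while ChatGLM's ChatGLMForConditionalGeneration exposes it as `.transformer`, so the engine can no longer hard-code `self.model.model`. A minimal sketch of the same dispatch, using stand-in objects instead of real model classes (everything here is illustrative, not part of the patch):

from types import SimpleNamespace

# Stand-ins for the two wrapper layouts (illustrative only).
llama_like = SimpleNamespace(model=SimpleNamespace())          # backbone at `.model`
chatglm_like = SimpleNamespace(transformer=SimpleNamespace())  # backbone at `.transformer`

def bind_infer_state(wrapper, batch_infer_state):
    # Same dispatch as the engine change: prefer `.model`, fall back to `.transformer`.
    model_to_bind = wrapper.model if hasattr(wrapper, "model") else wrapper.transformer
    model_to_bind.infer_state = batch_infer_state
    return model_to_bind

bind_infer_state(llama_like, batch_infer_state="dummy-state")
bind_infer_state(chatglm_like, batch_infer_state="dummy-state")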


@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple

 import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import logging

 from colossalai.inference.kv_cache import BatchInferState
@@ -83,6 +84,7 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        # This function is only used when pipeline is enabled.
         logger = logging.get_logger(__name__)

         if output_attentions:
@@ -136,11 +138,12 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        infer_state = infer_state or getattr(self, "infer_state", None)
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
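The two added lines in the hunk above are what decouple this forward from the pipeline: if the caller did not pass an infer_state, the model falls back to the one the engine bound onto it, and a missing stage_manager is treated as running the whole model in a single stage. Restated as standalone helpers for clarity (the helper names are illustrative, not part of the patch):

def resolve_infer_state(module, infer_state=None):
    # Fall back to the BatchInferState the engine attached to the module (see the engine hunk).
    return infer_state or getattr(module, "infer_state", None)

def acts_as_first_stage(stage_manager):
    # No pipeline means this process owns the embedding (first-stage) work.
    return stage_manager is None or stage_manager.is_first_stage()

def acts_as_last_stage(stage_manager):
    # Likewise it owns the final layer norm and output (last-stage) work.
    return stage_manager is None or stage_manager.is_last_stage()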
@@ -229,7 +232,7 @@
             )

         # Run encoder.
-        hidden_states = self.encoder(
+        hidden_states, next_cache = self.encoder(
             hidden_states,
             full_attention_mask,
             kv_caches=past_key_values,
@@ -246,6 +249,11 @@
         infer_state.seq_len += 1
         infer_state.max_len_in_batch += 1

+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         return {"hidden_states": hidden_states}

     @staticmethod
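With stage_manager absent, the forward above now returns in the standard Hugging Face shape instead of the pipeline's intermediate dict. A hedged usage sketch of that return branch (the tensors and the wrapper function are illustrative):

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

def finish_forward(hidden_states, next_cache, stage_manager=None, return_dict=True):
    # Mirrors the branch added above: no pipeline -> tuple or BaseModelOutputWithPast,
    # pipeline -> intermediate dict handed to the next stage.
    if stage_manager is None:
        if not return_dict:
            return (hidden_states, next_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
    return {"hidden_states": hidden_states}

out = finish_forward(torch.zeros(1, 4, 8), next_cache=())
print(type(out).__name__, tuple(out.last_hidden_state.shape))  # BaseModelOutputWithPast (1, 4, 8)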
@@ -264,10 +272,15 @@
         hidden_states = hidden_states.transpose(0, 1).contiguous()

         infer_state.decode_layer_id = 0
+        if stage_index is None:
+            stage_index = (0, len(self.layers))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if kv_caches is None:
             kv_caches = tuple([None] * (end_idx - start_idx + 1))

+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
+
         for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
             layer = self.layers[idx]
             layer_ret = layer(
@@ -279,15 +292,19 @@
             )

             infer_state.decode_layer_id += 1
-            hidden_states, _ = layer_ret
+            hidden_states, next_kv_cache = layer_ret
+            if use_cache:
+                next_decoder_cache += (next_kv_cache,)

         hidden_states = hidden_states.transpose(0, 1).contiguous()

-        if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
+        if self.post_layer_norm and (stage_manager is None or stage_manager.is_last_stage()):
             # Final layer norm.
             hidden_states = self.final_layernorm(hidden_states)

-        return hidden_states
+        next_cache = next_decoder_cache if use_cache else None
+        return hidden_states, next_cache

     @staticmethod
     def chatglm_glmblock_forward(