mirror of https://github.com/hpcaitech/ColossalAI
[inference] decouple pipeline logic for chatglm (#5098)
* [inference] decouple pipeline logic for chatglm
* [inference] fix chatglm modeling
refactor/inference
parent
cb450c2861
commit
f196f40a8f
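Taken together, the hunks below make the ChatGLM inference forwards work both under the pipeline scheduler and as a plain single-process model: without a stage manager they run every layer themselves and return standard outputs, while intermediate pipeline stages keep handing a {"hidden_states": ...} dict to the next stage. A minimal sketch of that split, using toy stand-ins rather than the ColossalAI classes:

def forward(hidden_states, num_layers, stage_manager=None, stage_index=None):
    """Toy stand-in for the decoupled forward, not the ColossalAI code itself."""
    if stage_index is None:
        # No pipeline schedule: this single call covers every layer.
        stage_index = (0, num_layers)
    start_idx, end_idx = stage_index
    for _ in range(start_idx, end_idx):
        hidden_states = hidden_states + 1  # placeholder for a transformer block

    if stage_manager is None or stage_manager.is_last_stage():
        # Single-process run or last pipeline stage: finished output.
        return {"last_hidden_state": hidden_states}
    # Intermediate pipeline stage: pass activations downstream.
    return {"hidden_states": hidden_states}


class ToyStageManager:
    """Minimal stand-in for a pipeline stage manager."""

    def __init__(self, is_last):
        self._is_last = is_last

    def is_last_stage(self):
        return self._is_last


print(forward(0, num_layers=4))                                                            # {'last_hidden_state': 4}
print(forward(0, num_layers=4, stage_manager=ToyStageManager(False), stage_index=(0, 2)))  # {'hidden_states': 2}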
@@ -156,7 +156,8 @@ class InferenceEngine:
             input_list, self.max_input_len, self.max_output_len, self.cache_manager_list[0]
         )
         # bind the infer state to the model (not lm model)
-        self.model.model.infer_state = batch_infer_state
+        model_to_bind = self.model.model if hasattr(self.model, "model") else self.model.transformer
+        model_to_bind.infer_state = batch_infer_state
         if generation_config is not None:
             generation_config.max_new_tokens = self.max_output_len
         else:
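The engine change above stops assuming every LM wrapper exposes its backbone as .model; ChatGLM's wrapper exposes it as .transformer, so the infer state is bound to whichever attribute exists. A small sketch of the same hasattr dispatch, using toy classes rather than the real InferenceEngine:

from types import SimpleNamespace


class LlamaStyleWrapper:
    """Toy LM wrapper whose backbone lives under `.model` (LLaMA-style)."""

    def __init__(self):
        self.model = SimpleNamespace()


class ChatGLMStyleWrapper:
    """Toy LM wrapper whose backbone lives under `.transformer` (ChatGLM-style)."""

    def __init__(self):
        self.transformer = SimpleNamespace()


def bind_infer_state(lm_model, batch_infer_state):
    # Same dispatch as the hunk above: pick whichever backbone attribute
    # exists and attach the per-batch inference state to it.
    model_to_bind = lm_model.model if hasattr(lm_model, "model") else lm_model.transformer
    model_to_bind.infer_state = batch_infer_state
    return model_to_bind


assert bind_infer_state(LlamaStyleWrapper(), "state").infer_state == "state"
assert bind_infer_state(ChatGLMStyleWrapper(), "state").infer_state == "state"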
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple
 
 import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.utils import logging
 
 from colossalai.inference.kv_cache import BatchInferState
@@ -83,6 +84,7 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        # This function is only used when pipeline is enabled.
         logger = logging.get_logger(__name__)
 
         if output_attentions:
@@ -136,11 +138,12 @@ class ChatGLM2InferenceForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        infer_state = infer_state or getattr(self, "infer_state", None)
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
 
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
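The added infer_state fallback above lets the same forward work when the pipeline scheduler does not pass the state explicitly: it falls back to the infer_state the engine bound onto the module in the first hunk. A minimal sketch of that argument-or-attribute fallback, with toy names:

class ToyBackbone:
    """Toy module showing the explicit-argument-or-bound-attribute fallback."""

    def forward(self, hidden_states, infer_state=None):
        # Prefer an explicitly passed state; otherwise fall back to the state
        # the engine attached to the module before calling it.
        infer_state = infer_state or getattr(self, "infer_state", None)
        return hidden_states, infer_state


backbone = ToyBackbone()
backbone.infer_state = {"seq_len": 3}                      # what the engine binding does
print(backbone.forward("h"))                               # ('h', {'seq_len': 3}) - bound state
print(backbone.forward("h", infer_state={"seq_len": 7}))   # ('h', {'seq_len': 7}) - explicit wins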
@@ -229,7 +232,7 @@ class ChatGLM2InferenceForwards:
         )
 
         # Run encoder.
-        hidden_states = self.encoder(
+        hidden_states, next_cache = self.encoder(
             hidden_states,
             full_attention_mask,
             kv_caches=past_key_values,
@@ -246,6 +249,11 @@ class ChatGLM2InferenceForwards:
             infer_state.seq_len += 1
             infer_state.max_len_in_batch += 1
 
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         return {"hidden_states": hidden_states}
 
     @staticmethod
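When stage_manager is None, the model-level forward above follows the usual transformers convention: a plain tuple when return_dict=False, a BaseModelOutputWithPast otherwise. A minimal sketch of that convention (the tensors are dummies; only the return shape is the point):

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast


def finish_forward(hidden_states, next_cache, return_dict=True):
    # Same convention as the added branch above: tuple for return_dict=False,
    # BaseModelOutputWithPast otherwise.
    if not return_dict:
        return (hidden_states, next_cache)
    return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)


h = torch.zeros(1, 4, 8)
print(type(finish_forward(h, None, return_dict=False)))  # <class 'tuple'>
print(type(finish_forward(h, None)))                     # BaseModelOutputWithPast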
@@ -264,10 +272,15 @@ class ChatGLM2InferenceForwards:
         hidden_states = hidden_states.transpose(0, 1).contiguous()
 
         infer_state.decode_layer_id = 0
+
+        if stage_index is None:
+            stage_index = (0, len(self.layers))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if kv_caches is None:
             kv_caches = tuple([None] * (end_idx - start_idx + 1))
 
+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
         for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
             layer = self.layers[idx]
             layer_ret = layer(
@@ -279,15 +292,19 @@ class ChatGLM2InferenceForwards:
             )
             infer_state.decode_layer_id += 1
 
-            hidden_states, _ = layer_ret
+            hidden_states, next_kv_cache = layer_ret
+            if use_cache:
+                next_decoder_cache += (next_kv_cache,)
 
         hidden_states = hidden_states.transpose(0, 1).contiguous()
 
-        if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
+        if self.post_layer_norm and (stage_manager is None or stage_manager.is_last_stage()):
             # Final layer norm.
             hidden_states = self.final_layernorm(hidden_states)
 
-        return hidden_states
+        next_cache = next_decoder_cache if use_cache else None
+
+        return hidden_states, next_cache
 
     @staticmethod
     def chatglm_glmblock_forward(
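The last two hunks change the encoder loop: with no pipeline there is no stage schedule to supply stage_index, so the encoder defaults to running every layer, and it now accumulates one kv-cache entry per layer so callers get a Hugging Face style cache back. A toy sketch of that layer-range loop (illustrative names, not the ChatGLM encoder):

def run_layers(layers, hidden_states, stage_index=None, kv_caches=None, use_cache=True):
    """Toy encoder loop mirroring the layer-range and cache handling above."""
    if stage_index is None:
        # No pipeline schedule supplied: this call is responsible for every layer.
        stage_index = (0, len(layers))
    start_idx, end_idx = stage_index[0], stage_index[1]
    if kv_caches is None:
        kv_caches = tuple([None] * (end_idx - start_idx + 1))

    # Accumulate one cache entry per layer so callers get an HF-style kv-cache back.
    next_decoder_cache = () if use_cache else None
    for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
        hidden_states, next_kv_cache = layers[idx](hidden_states, kv_cache)
        if use_cache:
            next_decoder_cache += (next_kv_cache,)

    next_cache = next_decoder_cache if use_cache else None
    return hidden_states, next_cache


# Toy "layers": each adds 1 to the activations and returns a fake kv entry.
toy_layers = [lambda h, kv: (h + 1, ("k", "v")) for _ in range(4)]
print(run_layers(toy_layers, 0))   # (4, (('k', 'v'), ('k', 'v'), ('k', 'v'), ('k', 'v')))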