fix(chat): fix stream_chat to return generator (#123)

pull/302/head
zhjunqin 2023-09-10 23:46:45 +08:00 committed by GitHub
parent 2ec20707d0
commit 8420115b5e
1 changed file with 53 additions and 17 deletions

@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 
 import torch
 import torch.utils.checkpoint
@@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     temperature: float = 0.8,
                     top_p: float = 0.8,
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
+        ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
+        """
+
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))
 
             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
 
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "<eoa>":
-                    print(token, end="")
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))
 
             def end(self):
-                print("")
+                self.queue.put(None)
 
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs
-        )
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                # None is the end-of-stream sentinel put by ChatStreamer.end()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()
 
 @add_start_docstrings(
     """