fix(chat): fix stream_chat to return generator (#123)

pull/302/head
zhjunqin 2023-09-10 23:46:45 +08:00 committed by GitHub
parent 2ec20707d0
commit 8420115b5e
1 changed file with 53 additions and 17 deletions


@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 
 import torch
 import torch.utils.checkpoint
@@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     temperature: float = 0.8,
                     top_p: float = 0.8,
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+        ('你好,有什么可以帮助您的吗？', [('你好', '你好,有什么可以帮助您的吗？')])
+        """
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))
 
             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
 
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
+                if token.strip() != "<eoa>":
+                    print(token, end="")
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))
 
             def end(self):
-                print("")
-
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs
-        )
+                self.queue.put(None)
+
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()
 
 @add_start_docstrings(
     """