diff --git a/tools/transformers/modeling_internlm.py b/tools/transformers/modeling_internlm.py
index df1e19f..da7aaa0 100644
--- a/tools/transformers/modeling_internlm.py
+++ b/tools/transformers/modeling_internlm.py
@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 
 import torch
 import torch.utils.checkpoint
@@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     temperature: float = 0.8,
                     top_p: float = 0.8,
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
+        ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
+        """
+
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
-
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))
+
             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
+
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "":
-                    print(token, end="")
-
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))
+
             def end(self):
-                print("")
-
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs
-        )
-
+                self.queue.put(None)
+
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()
+
     @add_start_docstrings(
         """
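
The change replaces the print-based streamer with a producer/consumer pair: a background thread runs self.chat with a ChatStreamer that pushes accumulated (response, history) pairs into a bounded queue, and the returned generator drains that queue until end() pushes the None sentinel. Below is a minimal usage sketch under some assumptions: the method shown is InternLMForCausalLM.stream_chat with the same leading parameters as chat (its name and full signature fall outside the hunk context), and the checkpoint name and prompt are illustrative only, not part of this diff.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint; any InternLM chat checkpoint shipping this
    # modeling file (loaded with trust_remote_code) would behave the same way.
    model_path = "internlm/internlm-chat-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, trust_remote_code=True
    ).cuda().eval()

    # Each yielded item is the accumulated (response, history) pair so far,
    # starting from the empty response that __init__ puts on the queue, so the
    # caller can overwrite the previous output or diff consecutive responses.
    for response, history in model.stream_chat(tokenizer, "你好", history=[]):
        print(response)

The bounded queue (maxsize=20) applies back-pressure: if the consumer stops reading, the generating thread blocks on put() instead of buffering the whole response in memory.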