diff --git a/stream_cli_demo.py b/stream_cli_demo.py
new file mode 100644
index 0000000..3fae35b
--- /dev/null
+++ b/stream_cli_demo.py
@@ -0,0 +1,81 @@
+import os
+from transformers import AutoTokenizer, AutoModel
+import signal
+import platform
+from stream_utils import ChatGLMStreamDecoder
+
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "THUDM/chatglm-6b", trust_remote_code=True)
+stream_decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
+model = AutoModel.from_pretrained(
+    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+# Set by signal_handler() to ask the streaming loop in main() to stop early.
+stop_stream = False
+
+
+def signal_handler(signum, frame):
+    """SIGINT handler: request that the reply being streamed is cut short."""
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    """Run the interactive chat loop, printing each reply as it is decoded."""
+    history = []
+    global stop_stream
+    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    while True:
+        query = input("\n用户:")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            # Flush the decoder and discard whatever it still had buffered.
+            stream_decoder.end()
+            stream_decoder.get()
+            os.system(clear_command)
+            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+            continue
+        gen_kwargs = {"max_length": 2048, "do_sample": True, "top_p": 0.7,
+                      "temperature": 0.95, "logits_processor": None}
+        if not history:
+            prompt = query
+        else:
+            prompt = "".join([f"[Round {i}]\n问:{q}\n答:{r}\n" for i, (q, r) in enumerate(
+                history)] + [f"[Round {len(history)}]\n问:{query}\n答:"])
+        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+        print("\nChatGLM-6B:", end="")
+        response = []
+        # Ctrl-C while streaming only cuts the current reply short; the
+        # previous SIGINT handler is restored afterwards so that Ctrl-C at the
+        # input prompt keeps its usual behavior.
+        stop_stream = False
+        previous_handler = signal.signal(signal.SIGINT, signal_handler)
+        try:
+            for outputs in model.stream_generate(**inputs, **gen_kwargs):
+                if stop_stream:
+                    break
+                # Feed only the last id of the first sequence to the decoder.
+                stream_decoder.put([int(outputs[0][-1])])
+                new_resp = stream_decoder.get().replace("<n>", "\n")
+                response.append(new_resp)
+                # flush so partial lines show up immediately
+                print(new_resp, end="", flush=True)
+        finally:
+            signal.signal(signal.SIGINT, previous_handler)
+        # end of line: flush what the decoder is still holding back
+        stream_decoder.end()
+        new_resp = stream_decoder.get().replace("<n>", "\n")
+        response.append(new_resp)
+        print(new_resp)
+        response = "".join(response)
+        history.append((query, response))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/stream_utils.py b/stream_utils.py
new file mode 100644
index 0000000..bf7f519
--- /dev/null
+++ b/stream_utils.py
@@ -0,0 +1,228 @@
+import sentencepiece as spm
+from typing import Tuple
+import re
+import unittest
+
+# Python implementation of https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
+
+
+def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool, sp: spm.SentencePieceProcessor, add_dummy_prefix=True, remove_extra_whitespaces=False) -> Tuple[str, bool]:
+    '''
+    Returns decoded piece and a boolean indicating if the function has consumed
+    a bos whitespace token (a piece starting with a kSpaceSymbol). This is used
+    to strip only the first whitespace token from the decoded sequence for
+    add_dummy_prefix.
+    '''
+    if sp.IsControl(id):  # <s>, </s>
+        return "", False  # invisible symbol.
+    elif sp.IsUnknown(id):
+        if sp.IdToPiece(id) == piece:  # <unk>
+            return SPStreamDecoder.DefaultUnknownSymbol, False
+        else:  # return piece when piece is not <unk>.
+            return piece, False
+    has_bos_ws = False  # whether the token starts with a kSpaceSymbol
+    # Consume if the current position is bos and
+    # piece starts with kSpaceSymbol.
+    if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
+        t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
+        has_bos_ws = t != piece
+        piece = t
+        # if we are removing extra whitespace, we remove all leading whitespace
+        if remove_extra_whitespaces:
+            has_bos_ws = False
+    return piece.replace(SPStreamDecoder.SpaceSymbol, " "), has_bos_ws
+
+
+def ProcessBytePieces(pieces: list[str]) -> str:
+    '''
+    Decodes a run of byte pieces into text (modified version of the original
+    code). Each piece is expected to look like "<0xNN>". Bytes that are not
+    valid UTF-8 are mapped to REPLACEMENT CHARACTER (U+FFFD).
+    '''
+    if not pieces:
+        return ""
+    # Strip the surrounding "<" and ">" and parse the hex value of every piece
+    # to rebuild the raw byte sequence.
+    raw = bytes(int(piece[1:-1], base=16) for piece in pieces)
+    # errors="replace" substitutes U+FFFD for each invalid sequence and keeps
+    # decoding after it, so no manual retry loop is needed.
+    return raw.decode("utf-8", errors="replace")
+
+
+class SPStreamDecoder:
+    '''
+    Incremental decoder for sentencepiece ids.
+
+    Feed ids with put(), fetch the text decoded so far with get(), and call
+    end() once a sequence is complete to flush pending byte pieces and reset
+    the state for the next sequence.
+    '''
+    SpaceSymbol = chr(0x2581)
+    DefaultUnknownSymbol = chr(0x2047)
+    ReplacementCharacter = chr(0xFFFD)
+
+    def __init__(self, sp: spm.SentencePieceProcessor, remove_extra_whitespaces=False, add_dummy_prefix=True) -> None:
+        self._sp = sp
+        self._bos_ws_seen = False
+        # 'is_bos_ws': whether we expect a bos ws token to consume.
+        self._is_bos_ws = True
+        self._nothing_decoded = True
+        # ids received by put() that have not been decoded yet
+        self._ids = []
+        # text decoded since the last get()
+        self._decoded = ""
+        # True between end() and the next put()
+        self._ending = False
+        self.remove_extra_whitespaces = remove_extra_whitespaces
+        self.add_dummy_prefix = add_dummy_prefix
+
+    def put(self, ids: list[int]) -> None:
+        '''Appends `ids` to the pending ids and decodes what can be decoded.'''
+        self._ending = False
+        self._ids += ids
+        self._decode(eos=False)
+
+    def end(self) -> None:
+        '''Flushes all pending ids and resets the state for a new sequence.'''
+        self._decode(eos=True)
+        self._is_bos_ws = True
+        self._bos_ws_seen = False
+        self._nothing_decoded = True
+        self._ending = True
+        self._ids = []
+
+    def _decode(self, eos=False) -> None:
+        '''
+        Decodes the pending ids and appends the text to the output buffer.
+
+        A trailing run of byte pieces is kept pending until a non-byte piece
+        follows it, unless `eos` is set, in which case it is decoded as is.
+        '''
+        consumed = 0
+        byte_pieces = []
+        for token_id in self._ids:
+            piece = self._sp.IdToPiece(token_id)
+            if self._sp.IsByte(token_id):
+                byte_pieces.append(piece)
+                continue
+            # A non-byte piece closes the pending run of byte pieces.
+            self._decoded += ProcessBytePieces(byte_pieces)
+            consumed += len(byte_pieces)
+            byte_pieces = []
+            if self._decoded:
+                self._nothing_decoded = False
+            # if we have seen a bos_ws token or any non-empty token
+            if self._bos_ws_seen or not self._nothing_decoded:
+                self._is_bos_ws = False
+            # Forward the whitespace options chosen in the constructor.
+            decoded, self._bos_ws_seen = DecodeSentencePiece(
+                piece, token_id, self._is_bos_ws, self._sp,
+                add_dummy_prefix=self.add_dummy_prefix,
+                remove_extra_whitespaces=self.remove_extra_whitespaces)
+            self._decoded += decoded
+            consumed += 1
+            if self._decoded:
+                self._nothing_decoded = False
+        if eos:
+            self._decoded += ProcessBytePieces(byte_pieces)
+        else:
+            self._ids = self._ids[consumed:]
+
+    def get(self) -> str:
+        '''Returns the text decoded since the last call and clears the buffer.'''
+        t = self._decoded
+        self._decoded = ""
+        return t
+
+
+class ChatGLMStreamDecoder(SPStreamDecoder):
+    '''Stream decoder that also rewrites ChatGLM special tokens and punctuation.'''
+
+    def get(self) -> str:
+        '''
+        Returns the post-processed text decoded since the last call.
+
+        Returns "" while the buffer may still end in an incomplete special
+        token, so that such a token is never emitted in parts.
+        '''
+        # if prefix of special tokens found, wait till it's impossible or end of decode
+        # (8 is the length of "[[训练时间]]"; 12 presumably covers "<|blank_NN|>"
+        # with a two-digit count - TODO confirm)
+        if "[" in self._decoded and len(self._decoded)-self._decoded.index("[") < 8 and not self._ending:
+            return ""
+        if "<" in self._decoded and len(self._decoded)-self._decoded.index("<") < 12 and not self._ending:
+            return ""
+        self._ending = False
+        t = self._decoded
+        self._decoded = ""
+        t = t.replace("<n>", "\n")
+        t = t.replace("[[训练时间]]", "2023年")
+        # [regex for the ASCII mark, full-width replacement]; only applied next to a CJK character
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            [r"\?", "?"],
+        ]
+        for item in punkts:
+            t = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], t)
+            t = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], t)
+        # "<|blank_N|>" stands for a run of N spaces.
+        t = re.sub(r"<\|blank_(\d+)\|>", lambda m: " " * int(m.group(1)), t)
+        return t
+
+
+class ChatGLMStreamDecoderTest(unittest.TestCase):
+    def test_ChatGLM_StreamDecoder(self):
+        '''Checks the stream decoder against the tokenizer/model reference output.'''
+        from transformers import AutoTokenizer, AutoModel
+        test_strings = [
+            "你好👋",  # multi-byte encoding
+            "Hello this is ChatGLM!",  # normal text
+            "你好👋 This is ChatGLM!",  # multi-byte encoding with tail
+            "!?.,!?。,",  # punctuations
+            "A\nB",  # "<n>" -> "\n"
+            "[[训练时间]]",  # training time token
+            "[[训练时间]123",  # broken training time token
+            # blank token (8 spaces -> <|blank_8|>). Note: It's hard to match the results of strip(), so add leading and tailing "1"
+            "1" + " " * 8 + "1",
+            "<|blank_8|123",  # broken blank token
+        ]
+        tokenizer = AutoTokenizer.from_pretrained(
+            "THUDM/chatglm-6b", trust_remote_code=True)
+        model = AutoModel.from_pretrained(
+            "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+        model = model.eval()
+        encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
+        stream_decoder = ChatGLMStreamDecoder(
+            tokenizer.sp_tokenizer.text_tokenizer.sp)
+        # original output
+        expected_outputs = [model.process_response(
+            tokenizer.decode(x)) for x in encoded_ids]
+        # decode token by token
+        decoded_strings_stream_token_by_token = []
+        for ids in encoded_ids:
+            res = []
+            for t in ids:
+                stream_decoder.put([t])
+                res.append(stream_decoder.get())
+            stream_decoder.end()
+            res.append(stream_decoder.get())
+            decoded_strings_stream_token_by_token.append("".join(res))
+        # decode all at once
+        decoded_strings_stream = []
+        for ids in encoded_ids:
+            stream_decoder.put(ids)
+            stream_decoder.end()
+            decoded_strings_stream.append(stream_decoder.get())
+        for i, (expected, token_by_token, all_at_once) in enumerate(zip(
+                expected_outputs, decoded_strings_stream_token_by_token, decoded_strings_stream)):
+            print(
+                f"Stream decoder test{i}: expected: '{expected}', token_by_token: '{token_by_token}', all at once: '{all_at_once}'")
+            self.assertEqual(expected, token_by_token)
+            self.assertEqual(expected, all_at_once)
+
+
+if __name__ == "__main__":
+    unittest.main()