From 430224bf13ba1329e41505b8ba98b49e2e6088d5 Mon Sep 17 00:00:00 2001
From: lwh9346
Date: Thu, 27 Apr 2023 00:07:42 +0800
Subject: [PATCH] Align output with ChatGLM model

Add unit test for stream decoder
---
 stream_cli_demo.py |  4 +-
 stream_utils.py    | 96 +++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/stream_cli_demo.py b/stream_cli_demo.py
index 1481e0c..3fae35b 100644
--- a/stream_cli_demo.py
+++ b/stream_cli_demo.py
@@ -2,12 +2,12 @@ import os
 from transformers import AutoTokenizer, AutoModel
 import signal
 import platform
-from stream_utils import SPStreamDecoder
+from stream_utils import ChatGLMStreamDecoder
 
 tokenizer = AutoTokenizer.from_pretrained(
     "THUDM/chatglm-6b", trust_remote_code=True)
-stream_decoder = SPStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
+stream_decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
 model = AutoModel.from_pretrained(
     "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
 model = model.eval()
diff --git a/stream_utils.py b/stream_utils.py
index 3c6a918..bf7f519 100644
--- a/stream_utils.py
+++ b/stream_utils.py
@@ -1,5 +1,7 @@
 import sentencepiece as spm
 from typing import Tuple
+import re
+import unittest
 
 
 # python implantation of https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
@@ -24,6 +26,7 @@ def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool, sp: spm.SentencePi
     if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
         t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
         has_bos_ws = t != piece
+        piece = t
         # if we are removing extra whitespace, we remove all leading whitespace
         if remove_extra_whitespaces:
             has_bos_ws = False
@@ -67,17 +70,21 @@ class SPStreamDecoder:
         self._nothing_decoded = True
         self._ids = []
         self._decoded = ""
+        self._ending = False
         self.remove_extra_whitespaces = remove_extra_whitespaces
         self.add_dummy_prefix = add_dummy_prefix
 
     def put(self, ids: list[int]) -> None:
+        self._ending = False
         self._ids += ids
         self._decode(eos=False)
 
     def end(self) -> None:
         self._decode(eos=True)
+        self._is_bos_ws = True
         self._bos_ws_seen = False
         self._nothing_decoded = True
+        self._ending = True
         self._ids = []
 
     def _decode(self, eos=False) -> None:
@@ -88,7 +95,7 @@
             if not self._sp.IsByte(self._ids[i]):
                 self._decoded += ProcessBytePieces(byte_pieces)
                 consumed += len(byte_pieces)
-                if consumed > 0:
+                if len(self._decoded) > 0:
                     self._nothing_decoded = False
                 byte_pieces = []
                 # if we have seen a bos_ws token or any non-empty token
@@ -98,7 +105,7 @@
                     piece, self._ids[i], self._is_bos_ws, self._sp)
                 self._decoded += decoded
                 consumed += 1
-                if consumed > 0:
+                if len(self._decoded) > 0:
                     self._nothing_decoded = False
             else:
                 byte_pieces.append(piece)
@@ -111,3 +118,88 @@
         t = self._decoded
         self._decoded = ""
         return t
+
+
+class ChatGLMStreamDecoder(SPStreamDecoder):
+
+    def get(self) -> str:
+        # if prefix of special tokens found, wait till it's impossible or end of decode
+        if "[" in self._decoded and len(self._decoded)-self._decoded.index("[") < 8 and not self._ending:
+            return ""
+        if "<" in self._decoded and len(self._decoded)-self._decoded.index("<") < 12 and not self._ending:
+            return ""
+        self._ending = False
+        t = self._decoded
+        self._decoded = ""
+        t = t.replace("<n>", "\n")
+        t = t.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", "，"],
+            ["!", "！"],
+            [":", "："],
+            [";", "；"],
+            ["\?", "？"],
+        ]
+        for item in punkts:
+            t = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], t)
+            t = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], t)
+        # for i in range(max_len, 1, -1):
+        #     t = t.replace(f"<|blank_{i}|>", " " * i)
+        for blank_token in re.findall(r"<\|blank_\d+\|>", t):
+            t = t.replace(blank_token, " " *
+                          int(re.search(r"\d+", blank_token)[0]))
+        return t
+
+
+class ChatGLMStreamDecoderTest(unittest.TestCase):
+    def test_ChatGLM_StreamDecoder(self):
+        from transformers import AutoTokenizer, AutoModel
+        test_strings = [
+            "你好👋",  # multi-byte encoding
+            "Hello this is ChatGLM!",  # normal text
+            "你好👋 This is ChatGLM!",  # multi-byte encoding with tail
+            "!?.,!?。,",  # punctuations
+            "A\nB",  # "<n>" -> "\n"
+            "[[训练时间]]",  # training time token
+            "[[训练时间]123",  # broken training time token
+            "1 1",  # blank token. Note: It's hard to match the results of strip(), so add leading and trailing "1"
+            "<|blank_8|123",  # broken blank token
+        ]
+        tokenizer = AutoTokenizer.from_pretrained(
+            "THUDM/chatglm-6b", trust_remote_code=True)
+        model = AutoModel.from_pretrained(
+            "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+        model = model.eval()
+        encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
+        stream_decoder = ChatGLMStreamDecoder(
+            tokenizer.sp_tokenizer.text_tokenizer.sp)
+        # original output
+        expected_outputs = [model.process_response(
+            tokenizer.decode(x)) for x in encoded_ids]
+        # decode token by token
+        decoded_strings_stream_token_by_token = [None for _ in test_strings]
+        for i in range(len(test_strings)):
+            res = []
+            for t in encoded_ids[i]:
+                stream_decoder.put([t])
+                res.append(stream_decoder.get())
+            stream_decoder.end()
+            res.append(stream_decoder.get())
+            res = "".join(res)
+            decoded_strings_stream_token_by_token[i] = res
+        # decode all at once
+        decoded_strings_stream = [None for _ in test_strings]
+        for i in range(len(test_strings)):
+            stream_decoder.put(encoded_ids[i])
+            stream_decoder.end()
+            decoded_strings_stream[i] = stream_decoder.get()
+        for i in range(len(test_strings)):
+            print(
+                f"Stream decoder test{i}: expected: '{expected_outputs[i]}', token_by_token: '{decoded_strings_stream_token_by_token[i]}', all at once: '{decoded_strings_stream[i]}'")
+            self.assertEqual(
+                expected_outputs[i], decoded_strings_stream_token_by_token[i])
+            self.assertEqual(expected_outputs[i], decoded_strings_stream[i])
+
+
+if __name__ == "__main__":
+    unittest.main()
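Note (not part of the patch): a minimal usage sketch of how the new ChatGLMStreamDecoder is meant to be driven, mirroring the token-by-token loop in the unit test above. The `new_token_ids` source below is an illustrative stand-in for the ids produced incrementally by the generation loop in stream_cli_demo.py; all other names come from the patch.

    # Illustrative sketch only: feed ids to ChatGLMStreamDecoder as they arrive.
    from transformers import AutoTokenizer
    from stream_utils import ChatGLMStreamDecoder

    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)

    # Stand-in for ids yielded one at a time by the model's generation loop.
    new_token_ids = tokenizer("你好, 这是 ChatGLM!")["input_ids"]

    pieces = []
    for token_id in new_token_ids:
        decoder.put([token_id])       # feed ids as they arrive
        pieces.append(decoder.get())  # may return "" while a possible special-token prefix ("[", "<") is pending
    decoder.end()                     # flush once the stream is finished
    pieces.append(decoder.get())
    print("".join(pieces))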