import sentencepiece as spm
from typing import Tuple
import re
import unittest

# Python implementation of
# https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc


def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool,
                        sp: spm.SentencePieceProcessor,
                        add_dummy_prefix=True,
                        remove_extra_whitespaces=False) -> Tuple[str, bool]:
    '''
    Returns the decoded piece and a boolean indicating whether the function has
    consumed a bos whitespace token (a piece starting with a kSpaceSymbol).
    This is used to strip only the first whitespace token from the decoded
    sequence for add_dummy_prefix.
    '''
    if sp.IsControl(id):  # <s>, </s>
        return "", False  # invisible symbol.
    elif sp.IsUnknown(id):
        if sp.IdToPiece(id) == piece:  # <unk>
            return SPStreamDecoder.DefaultUnknownSymbol, False
        else:  # return piece when piece is not <unk>.
            return piece, False

    has_bos_ws = False  # whether the token starts with a kSpaceSymbol

    # Consume if the current position is bos and
    # piece starts with kSpaceSymbol.
    if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
        t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
        has_bos_ws = t != piece
        piece = t
        # if we are removing extra whitespace, we remove all leading whitespace
        if remove_extra_whitespaces:
            has_bos_ws = False

    return piece.replace(SPStreamDecoder.SpaceSymbol, " "), has_bos_ws


def ProcessBytePieces(pieces: list[str]) -> str:
    '''
    Modified version of the original code.
    '''
    if len(pieces) == 0:
        return ""

    surfaces = ""

    # Constructs byte sequence.
    bytes_ = bytes([int(piece[1:-1], base=16) for piece in pieces])

    # Set surfaces of `bytes_` for each Unicode character.
    while len(bytes_) > 0:
        try:
            surfaces += bytes_.decode('utf-8')
            break
        except UnicodeDecodeError as e:
            # The byte piece at `e.start` is structurally invalid. Map it to
            # REPLACEMENT CHARACTER (U+FFFD).
            surfaces += bytes_[:e.start].decode('utf-8')
            surfaces += SPStreamDecoder.ReplacementCharacter
            bytes_ = bytes_[e.end:]
            continue

    return surfaces
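
# Illustrative example (not part of the original module): a minimal sketch of how
# ProcessBytePieces maps SentencePiece byte-fallback pieces of the form "<0xXX>"
# back to text, with structurally invalid bytes replaced by U+FFFD. The helper
# below is never called by the decoder or the tests; it only documents behaviour.
def _process_byte_pieces_demo() -> None:
    # A complete UTF-8 sequence decodes to the original character.
    assert ProcessBytePieces(["<0xE4>", "<0xBD>", "<0xA0>"]) == "你"
    # An invalid byte in the middle becomes U+FFFD; the surrounding bytes survive.
    assert ProcessBytePieces(["<0x41>", "<0xFF>", "<0x42>"]) == "A\ufffdB"
    # A truncated multi-byte sequence collapses to a single U+FFFD.
    assert ProcessBytePieces(["<0xF0>", "<0x9F>"]) == "\ufffd"
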
class SPStreamDecoder:
    SpaceSymbol = chr(0x2581)
    DefaultUnknownSymbol = chr(0x2047)
    ReplacementCharacter = chr(0xFFFD)

    def __init__(self, sp: spm.SentencePieceProcessor,
                 remove_extra_whitespaces=False,
                 add_dummy_prefix=True) -> None:
        self._sp = sp
        self._bos_ws_seen = False
        # 'is_bos_ws': whether we expect a bos ws token to consume.
        self._is_bos_ws = True
        self._nothing_decoded = True
        self._ids = []
        self._decoded = ""
        self._ending = False
        self.remove_extra_whitespaces = remove_extra_whitespaces
        self.add_dummy_prefix = add_dummy_prefix

    def put(self, ids: list[int]) -> None:
        self._ending = False
        self._ids += ids
        self._decode(eos=False)

    def end(self) -> None:
        self._decode(eos=True)
        self._is_bos_ws = True
        self._bos_ws_seen = False
        self._nothing_decoded = True
        self._ending = True
        self._ids = []

    def _decode(self, eos=False) -> None:
        pieces = [self._sp.IdToPiece(i) for i in self._ids]
        consumed = 0
        byte_pieces = []
        for i, piece in enumerate(pieces):
            if not self._sp.IsByte(self._ids[i]):
                self._decoded += ProcessBytePieces(byte_pieces)
                consumed += len(byte_pieces)
                if len(self._decoded) > 0:
                    self._nothing_decoded = False
                byte_pieces = []
                # if we have seen a bos_ws token or any non-empty token
                if self._bos_ws_seen or (not self._nothing_decoded):
                    self._is_bos_ws = False
                decoded, self._bos_ws_seen = DecodeSentencePiece(
                    piece, self._ids[i], self._is_bos_ws, self._sp,
                    add_dummy_prefix=self.add_dummy_prefix,
                    remove_extra_whitespaces=self.remove_extra_whitespaces)
                self._decoded += decoded
                consumed += 1
                if len(self._decoded) > 0:
                    self._nothing_decoded = False
            else:
                byte_pieces.append(piece)
        if eos:
            self._decoded += ProcessBytePieces(byte_pieces)
        else:
            self._ids = self._ids[consumed:]

    def get(self) -> str:
        t = self._decoded
        self._decoded = ""
        return t


class ChatGLMStreamDecoder(SPStreamDecoder):
    def get(self) -> str:
        # If a prefix of a special token is found, wait until it can no longer
        # be completed or until the decode has ended.
        if ("[" in self._decoded
                and len(self._decoded) - self._decoded.index("[") < 8
                and not self._ending):
            return ""
        if ("<" in self._decoded
                and len(self._decoded) - self._decoded.index("<") < 12
                and not self._ending):
            return ""
        self._ending = False
        t = self._decoded
        self._decoded = ""
        t = t.replace("<n>", "\n")
        t = t.replace("[[训练时间]]", "2023年")
        punkts = [
            [",", "，"],
            ["!", "！"],
            [":", "："],
            [";", "；"],
            [r"\?", "？"],
        ]
        for item in punkts:
            t = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], t)
            t = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], t)
        # for i in range(max_len, 1, -1):
        #     t = t.replace(f"<|blank_{i}|>", " " * i)
        for blank_token in re.findall(r"<\|blank_\d+\|>", t):
            t = t.replace(blank_token,
                          " " * int(re.search(r"\d+", blank_token)[0]))
        return t
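
# Usage sketch (illustrative, not part of the original module): the streaming
# protocol is put() for each chunk of newly generated token ids, get() for the
# text that can be emitted so far, and end() followed by a final get() to flush
# whatever is still buffered (partial byte runs, possible special-token prefixes).
# The helper name and the token_id_chunks argument are invented for this example;
# the tokenizer attributes match the ones used by the unit test below.
def _stream_decode_demo(token_id_chunks: list[list[int]]) -> None:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
    for chunk in token_id_chunks:
        decoder.put(chunk)                        # feed newly generated token ids
        print(decoder.get(), end="", flush=True)  # emit what is safe so far
    decoder.end()                                 # mark end of generation
    print(decoder.get(), flush=True)              # flush the remaining buffer
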
class ChatGLMStreamDecoderTest(unittest.TestCase):
    def test_ChatGLM_StreamDecoder(self):
        from transformers import AutoTokenizer, AutoModel
        test_strings = [
            "你好👋",  # multi-byte encoding
            "Hello this is ChatGLM!",  # normal text
            "你好👋 This is ChatGLM!",  # multi-byte encoding with tail
            "!?.,!?。,",  # punctuations
            "A\nB",  # "<n>" -> "\n"
            "[[训练时间]]",  # training time token
            "[[训练时间]123",  # broken training time token
            # blank token. Note: it's hard to match the results of strip(),
            # so add a leading and trailing "1".
            "1 1",
            "<|blank_8|123",  # broken blank token
        ]
        tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
        model = model.eval()
        encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
        stream_decoder = ChatGLMStreamDecoder(
            tokenizer.sp_tokenizer.text_tokenizer.sp)

        # original output
        expected_outputs = [model.process_response(
            tokenizer.decode(x)) for x in encoded_ids]

        # decode token by token
        decoded_strings_stream_token_by_token = [None for _ in test_strings]
        for i in range(len(test_strings)):
            res = []
            for t in encoded_ids[i]:
                stream_decoder.put([t])
                res.append(stream_decoder.get())
            stream_decoder.end()
            res.append(stream_decoder.get())
            res = "".join(res)
            decoded_strings_stream_token_by_token[i] = res

        # decode all at once
        decoded_strings_stream = [None for _ in test_strings]
        for i in range(len(test_strings)):
            stream_decoder.put(encoded_ids[i])
            stream_decoder.end()
            decoded_strings_stream[i] = stream_decoder.get()

        for i in range(len(test_strings)):
            print(
                f"Stream decoder test{i}: expected: '{expected_outputs[i]}', "
                f"token_by_token: '{decoded_strings_stream_token_by_token[i]}', "
                f"all at once: '{decoded_strings_stream[i]}'")
            self.assertEqual(
                expected_outputs[i], decoded_strings_stream_token_by_token[i])
            self.assertEqual(expected_outputs[i], decoded_strings_stream[i])


if __name__ == "__main__":
    unittest.main()