mirror of https://github.com/THUDM/ChatGLM-6B
parent
d443215bea
commit
430224bf13
|
@ -2,12 +2,12 @@ import os
|
|||
from transformers import AutoTokenizer, AutoModel
|
||||
import signal
|
||||
import platform
|
||||
from stream_utils import SPStreamDecoder
|
||||
from stream_utils import ChatGLMStreamDecoder
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"THUDM/chatglm-6b", trust_remote_code=True)
|
||||
stream_decoder = SPStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
|
||||
stream_decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
|
||||
model = AutoModel.from_pretrained(
|
||||
"THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
||||
model = model.eval()
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import sentencepiece as spm
|
||||
from typing import Tuple
|
||||
import re
|
||||
import unittest
|
||||
|
||||
# Python implementation of https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
|
||||
|
||||
|
@ -24,6 +26,7 @@ def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool, sp: spm.SentencePi
|
|||
if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
|
||||
t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
|
||||
has_bos_ws = t != piece
|
||||
piece = t
|
||||
# if we are removing extra whitespace, we remove all leading whitespace
|
||||
if remove_extra_whitespaces:
|
||||
has_bos_ws = False
|
||||
|
@ -67,17 +70,21 @@ class SPStreamDecoder:
|
|||
self._nothing_decoded = True
|
||||
self._ids = []
|
||||
self._decoded = ""
|
||||
self._ending = False
|
||||
self.remove_extra_whitespaces = remove_extra_whitespaces
|
||||
self.add_dummy_prefix = add_dummy_prefix
|
||||
|
||||
def put(self, ids: list[int]) -> None:
    """Append token ids to the pending buffer and run an incremental decode.

    Args:
        ids: token ids to add to the ids that are still waiting to be decoded.
    """
    # Fresh input arrived, so the decoder is no longer in the "just ended" state.
    self._ending = False
    self._ids.extend(ids)
    # eos=False: more ids may still follow for this sequence.
    self._decode(eos=False)
|
||||
|
||||
def end(self) -> None:
    """Close the current sequence: decode once more with eos=True, then reset.

    The text accumulated in ``self._decoded`` is not touched here, so it can
    still be collected after this call.
    """
    # The final decode must run before any state below is reset.
    self._decode(eos=True)

    # Return the per-sequence bookkeeping to its start-of-sequence values.
    self._ids = []
    self._nothing_decoded = True
    self._bos_ws_seen = False
    self._is_bos_ws = True

    # Mark the stream as finished until the next put().
    self._ending = True
|
||||
|
||||
def _decode(self, eos=False) -> None:
|
||||
|
@ -88,7 +95,7 @@ class SPStreamDecoder:
|
|||
if not self._sp.IsByte(self._ids[i]):
|
||||
self._decoded += ProcessBytePieces(byte_pieces)
|
||||
consumed += len(byte_pieces)
|
||||
if consumed > 0:
|
||||
if len(self._decoded) > 0:
|
||||
self._nothing_decoded = False
|
||||
byte_pieces = []
|
||||
# if we have seen a bos_ws token or any non-empty token
|
||||
|
@ -98,7 +105,7 @@ class SPStreamDecoder:
|
|||
piece, self._ids[i], self._is_bos_ws, self._sp)
|
||||
self._decoded += decoded
|
||||
consumed += 1
|
||||
if consumed > 0:
|
||||
if len(self._decoded) > 0:
|
||||
self._nothing_decoded = False
|
||||
else:
|
||||
byte_pieces.append(piece)
|
||||
|
@ -111,3 +118,88 @@ class SPStreamDecoder:
|
|||
t = self._decoded
|
||||
self._decoded = ""
|
||||
return t
|
||||
|
||||
|
||||
# Literal special token that is rewritten to a fixed string.
_TRAINING_TIME_TOKEN = "[[训练时间]]"

# "<|blank_N|>" stands for N consecutive spaces.
_BLANK_TOKEN_RE = re.compile(r"<\|blank_(\d+)\|>")

_CJK_CHAR = r"[\u4e00-\u9fff]"

# ASCII punctuation -> full-width counterpart, applied only next to a CJK
# character.  The targets are written as escapes so they cannot be silently
# normalised back to ASCII (which would turn every rule into a no-op).
_PUNCTUATION_RULES = tuple(
    (
        re.compile("(%s)%s" % (_CJK_CHAR, re.escape(ascii_mark))),
        re.compile("%s(%s)" % (re.escape(ascii_mark), _CJK_CHAR)),
        fullwidth_mark,
    )
    for ascii_mark, fullwidth_mark in (
        (",", "\uff0c"),
        ("!", "\uff01"),
        (":", "\uff1a"),
        (";", "\uff1b"),
        ("?", "\uff1f"),
    )
)


def _escaped_proper_prefixes(token: str) -> list[str]:
    """Return each non-empty proper prefix of *token*, longest first, regex-escaped."""
    return [re.escape(token[:end]) for end in range(len(token) - 1, 0, -1)]


# Matches when the text ENDS with an unfinished special token, i.e. some
# suffix of the text is a proper prefix of "<n>", "[[训练时间]]" or
# "<|blank_N|>".  \Z (not $) so a trailing newline cannot hide a partial token.
_PARTIAL_TOKEN_RE = re.compile(
    "(?:%s)\\Z" % "|".join(
        [r"<\|blank_\d+\|?"]
        + _escaped_proper_prefixes("<|blank_0")
        + _escaped_proper_prefixes("<n>")
        + _escaped_proper_prefixes(_TRAINING_TIME_TOKEN)
    )
)


class ChatGLMStreamDecoder(SPStreamDecoder):
    """Stream decoder that also applies ChatGLM's text post-processing.

    On top of the text buffered by ``SPStreamDecoder`` it:

    * replaces ``<n>`` with a newline,
    * replaces ``[[训练时间]]`` with ``2023年``,
    * converts ASCII ``, ! : ; ?`` adjacent to a CJK character to full-width,
    * expands ``<|blank_N|>`` to N spaces.
    """

    def get(self) -> str:
        """Return the post-processed text decoded since the last call.

        While the buffer ends with an unfinished special token the method
        returns ``""`` and keeps the whole buffer, so a token that arrives in
        several pieces is never emitted half-replaced.  After ``end()`` the
        buffer is flushed unconditionally.

        Returns:
            The next chunk of text, or ``""`` if output is being held back.
        """
        # Hold back only when the *tail* can still grow into a special token.
        # (Measuring from the first "[" / "<" in the buffer let a later,
        # half-received token slip out whenever an earlier bracket existed.)
        if not self._ending and _PARTIAL_TOKEN_RE.search(self._decoded):
            return ""
        self._ending = False

        text, self._decoded = self._decoded, ""
        text = text.replace("<n>", "\n")
        text = text.replace(_TRAINING_TIME_TOKEN, "2023年")

        # NOTE(review): rules only see the current chunk, so a mark whose CJK
        # neighbour was flushed in an earlier chunk is left as ASCII.
        for after_cjk, before_cjk, fullwidth_mark in _PUNCTUATION_RULES:
            text = after_cjk.sub(r"\1" + fullwidth_mark, text)
            text = before_cjk.sub(fullwidth_mark + r"\1", text)

        # Single pass instead of findall + replace + search per token.
        return _BLANK_TOKEN_RE.sub(lambda m: " " * int(m.group(1)), text)
|
||||
|
||||
|
||||
class ChatGLMStreamDecoderTest(unittest.TestCase):
    """Checks ChatGLMStreamDecoder against the model's own non-streaming decode."""

    def test_ChatGLM_StreamDecoder(self):
        """Streamed output must equal ``process_response(tokenizer.decode(ids))``.

        Each sample is decoded twice with the same decoder instance: one token
        id at a time, and all ids in a single ``put``.  Requires the
        ``THUDM/chatglm-6b`` checkpoint and a CUDA device.
        """
        from transformers import AutoTokenizer, AutoModel

        test_strings = [
            "你好👋",  # multi-byte encoding
            "Hello this is ChatGLM!",  # normal text
            "你好👋 This is ChatGLM!",  # multi-byte encoding with tail
            "!?.,!?。,",  # punctuations
            "A\nB",  # "<n>" -> "\n"
            "[[训练时间]]",  # training time token
            "[[训练时间]123",  # broken training time token
            # Blank token: a run of spaces is what gets encoded as
            # "<|blank_N|>".  Built explicitly so the run cannot be collapsed
            # to a single space.  Leading and trailing "1" are added because
            # it is hard to match the results of strip().
            "1" + " " * 8 + "1",
            "<|blank_8|123",  # broken blank token
        ]

        tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
        model = model.eval()

        encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
        stream_decoder = ChatGLMStreamDecoder(
            tokenizer.sp_tokenizer.text_tokenizer.sp)

        # Reference: the original, non-streaming output.
        expected_outputs = [
            model.process_response(tokenizer.decode(ids))
            for ids in encoded_ids
        ]

        # Decode token by token, collecting every partial chunk.
        token_by_token_outputs = []
        for ids in encoded_ids:
            chunks = []
            for token_id in ids:
                stream_decoder.put([token_id])
                chunks.append(stream_decoder.get())
            stream_decoder.end()
            chunks.append(stream_decoder.get())
            token_by_token_outputs.append("".join(chunks))

        # Decode all ids at once.
        all_at_once_outputs = []
        for ids in encoded_ids:
            stream_decoder.put(ids)
            stream_decoder.end()
            all_at_once_outputs.append(stream_decoder.get())

        cases = zip(test_strings, expected_outputs,
                    token_by_token_outputs, all_at_once_outputs)
        for i, (source, expected, streamed, whole) in enumerate(cases):
            print(
                f"Stream decoder test{i}: expected: '{expected}', token_by_token: '{streamed}', all at once: '{whole}'")
            # subTest keeps checking the remaining samples after a failure.
            with self.subTest(index=i, source=source):
                self.assertEqual(expected, streamed)
                self.assertEqual(expected, whole)
|
||||
|
||||
|
||||
# Running this module as a script executes the unit tests defined above;
# importing it does not.
if __name__ == "__main__":
    unittest.main()
|
||||
|
|
Loading…
Reference in New Issue