mirror of https://github.com/THUDM/ChatGLM-6B
Merge 430224bf13 into 401bf3a8a7
commit ae03f3609a

@@ -0,0 +1,65 @@
import os
import signal
import platform

from transformers import AutoTokenizer, AutoModel

from stream_utils import ChatGLMStreamDecoder

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)
# The stream decoder wraps the raw sentencepiece processor used by the tokenizer.
stream_decoder = ChatGLMStreamDecoder(tokenizer.sp_tokenizer.text_tokenizer.sp)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False

# "Welcome to the ChatGLM-6B model. Type to chat, 'clear' to clear the
# conversation history, 'stop' to terminate the program."
WELCOME = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"


def signal_handler(sig, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    # Ctrl-C stops the current generation instead of killing the process.
    signal.signal(signal.SIGINT, signal_handler)
    print(WELCOME)
    while True:
        query = input("\n用户:")  # "User:"
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            # Flush any state left in the decoder before starting over.
            stream_decoder.end()
            stream_decoder.get()
            os.system(clear_command)
            print(WELCOME)
            continue
        gen_kwargs = {"max_length": 2048, "do_sample": True, "top_p": 0.7,
                      "temperature": 0.95, "logits_processor": None}
        if not history:
            prompt = query
        else:
            # ChatGLM-6B's multi-round prompt format ("问" = question, "答" = answer).
            prompt = "".join([f"[Round {i}]\n问:{q}\n答:{r}\n" for i, (q, r) in enumerate(
                history)] + [f"[Round {len(history)}]\n问:{query}\n答:"])
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        print("\nChatGLM-6B:", end="")
        response = []
        for outputs in model.stream_generate(**inputs, **gen_kwargs):
            if stop_stream:
                stop_stream = False
                break
            # Feed only the newly generated token id to the stream decoder.
            stream_decoder.put([int(outputs[0][-1])])
            new_resp = stream_decoder.get().replace("<n>", "\n")
            response.append(new_resp)
            print(new_resp, end="", flush=True)
        # End of generation: flush whatever the decoder is still holding back.
        stream_decoder.end()
        new_resp = stream_decoder.get().replace("<n>", "\n")
        response.append(new_resp)
        print(new_resp)
        response = "".join(response)
        history.append((query, response))


if __name__ == "__main__":
    main()
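Why the demo routes tokens through a stream decoder instead of calling tokenizer.decode on each new id: sentencepiece represents out-of-vocabulary characters as raw byte pieces (e.g. <0xE4>), so a single Chinese character or emoji can span several tokens, and decoding at a token boundary inside such a sequence yields U+FFFD replacement characters. A minimal standalone sketch of the failure mode (plain Python, no model needed):

# "你" is three UTF-8 bytes; cutting the stream mid-character tears it.
b = "你".encode("utf-8")                         # b'\xe4\xbd\xa0'
print(b[:2].decode("utf-8", errors="replace"))   # '�' (torn character)
print(b.decode("utf-8"))                         # '你' (complete sequence)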
@@ -0,0 +1,205 @@
import re
import unittest
from typing import Tuple

import sentencepiece as spm

# Python implementation of the decoding logic in
# https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
def DecodeSentencePiece(piece: str, id: int, is_bos_ws: bool,
                        sp: spm.SentencePieceProcessor,
                        add_dummy_prefix=True,
                        remove_extra_whitespaces=False) -> Tuple[str, bool]:
    '''
    Returns the decoded piece and a boolean indicating whether the function has
    consumed a bos whitespace token (a piece starting with kSpaceSymbol). This
    is used to strip only the first whitespace token from the decoded sequence
    for add_dummy_prefix.
    '''
    if sp.IsControl(id):  # <s>, </s>
        return "", False  # invisible symbol.
    elif sp.IsUnknown(id):
        if sp.IdToPiece(id) == piece:  # <unk>
            return SPStreamDecoder.DefaultUnknownSymbol, False
        else:  # return the piece itself when it is not <unk>.
            return piece, False
    has_bos_ws = False  # whether the token starts with a kSpaceSymbol
    # Consume if the current position is bos and the piece starts with
    # kSpaceSymbol.
    if is_bos_ws and (add_dummy_prefix or remove_extra_whitespaces):
        t = piece.removeprefix(SPStreamDecoder.SpaceSymbol)
        has_bos_ws = t != piece
        piece = t
        # If we are removing extra whitespace, we remove all leading whitespace.
        if remove_extra_whitespaces:
            has_bos_ws = False
    return piece.replace(SPStreamDecoder.SpaceSymbol, " "), has_bos_ws
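# A sketch of the contract above (not part of the original file): with a
# hypothetical processor `sp` in which id 42 maps to the ordinary piece
# "▁Hello", the first call at beginning-of-sequence strips the dummy-prefix
# space; later calls keep it as a real space:
#
#   DecodeSentencePiece("▁Hello", 42, is_bos_ws=True, sp=sp)   # -> ("Hello", True)
#   DecodeSentencePiece("▁Hello", 42, is_bos_ws=False, sp=sp)  # -> (" Hello", False)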
def ProcessBytePieces(pieces: list[str]) -> str:
    '''
    Modified version of the original code.
    '''
    if len(pieces) == 0:
        return ""
    surfaces = ""
    # Constructs the byte sequence from pieces like "<0xE4>".
    bytes_ = bytes([int(piece[1:-1], base=16) for piece in pieces])
    # Set surfaces of `bytes` for each Unicode character.
    while len(bytes_) > 0:
        try:
            surfaces += bytes_.decode('utf-8')
            break
        except UnicodeDecodeError as e:
            # The byte piece at `e.start` is structurally invalid. Map it to
            # REPLACEMENT CHARACTER (U+FFFD).
            surfaces += bytes_[:e.start].decode('utf-8')
            surfaces += SPStreamDecoder.ReplacementCharacter
            bytes_ = bytes_[e.end:]
            continue
    return surfaces
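# Example behaviour of ProcessBytePieces (byte pieces in the "<0xNN>" form
# sentencepiece emits for out-of-vocabulary UTF-8 text):
#
#   ProcessBytePieces(["<0xE4>", "<0xBD>", "<0xA0>"])  # -> "你"
#   ProcessBytePieces(["<0xE4>", "<0xBD>"])            # -> "\ufffd" (truncated sequence)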
class SPStreamDecoder:
    SpaceSymbol = chr(0x2581)           # '▁' (kSpaceSymbol)
    DefaultUnknownSymbol = chr(0x2047)  # '⁇'
    ReplacementCharacter = chr(0xFFFD)  # '�'

    def __init__(self, sp: spm.SentencePieceProcessor,
                 remove_extra_whitespaces=False, add_dummy_prefix=True) -> None:
        self._sp = sp
        self._bos_ws_seen = False
        # '_is_bos_ws': whether we expect a bos ws token to consume.
        self._is_bos_ws = True
        self._nothing_decoded = True
        self._ids = []
        self._decoded = ""
        self._ending = False
        self.remove_extra_whitespaces = remove_extra_whitespaces
        self.add_dummy_prefix = add_dummy_prefix

    def put(self, ids: list[int]) -> None:
        self._ending = False
        self._ids += ids
        self._decode(eos=False)

    def end(self) -> None:
        self._decode(eos=True)
        self._is_bos_ws = True
        self._bos_ws_seen = False
        self._nothing_decoded = True
        self._ending = True
        self._ids = []

    def _decode(self, eos=False) -> None:
        pieces = [self._sp.IdToPiece(i) for i in self._ids]
        consumed = 0
        byte_pieces = []
        for i, piece in enumerate(pieces):
            if not self._sp.IsByte(self._ids[i]):
                # A non-byte piece closes any pending run of byte pieces.
                self._decoded += ProcessBytePieces(byte_pieces)
                consumed += len(byte_pieces)
                if len(self._decoded) > 0:
                    self._nothing_decoded = False
                byte_pieces = []
                # If we have seen a bos_ws token or any non-empty token, we
                # are no longer at the beginning of the sequence.
                if self._bos_ws_seen or (not self._nothing_decoded):
                    self._is_bos_ws = False
                decoded, self._bos_ws_seen = DecodeSentencePiece(
                    piece, self._ids[i], self._is_bos_ws, self._sp,
                    self.add_dummy_prefix, self.remove_extra_whitespaces)
                self._decoded += decoded
                consumed += 1
                if len(self._decoded) > 0:
                    self._nothing_decoded = False
            else:
                # Buffer byte pieces until the UTF-8 sequence is complete.
                byte_pieces.append(piece)
        if eos:
            self._decoded += ProcessBytePieces(byte_pieces)
        else:
            # Keep unconsumed trailing byte pieces for the next put().
            self._ids = self._ids[consumed:]

    def get(self) -> str:
        t = self._decoded
        self._decoded = ""
        return t
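# Typical streaming usage (a sketch; `sp` is a loaded SentencePieceProcessor
# and `ids` the token ids produced so far):
#
#   decoder = SPStreamDecoder(sp)
#   for token_id in ids:
#       decoder.put([token_id])
#       print(decoder.get(), end="")   # emits only what is safely decodable
#   decoder.end()
#   print(decoder.get())               # flush the remainder at end of stream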
class ChatGLMStreamDecoder(SPStreamDecoder):

    def get(self) -> str:
        # Hold output back while the tail could still grow into a special
        # token: "[[训练时间]]" is 8 characters, "<|blank_nn|>" up to 12.
        # Release it once a match is impossible or the decode has ended.
        if "[" in self._decoded and len(self._decoded) - self._decoded.index("[") < 8 and not self._ending:
            return ""
        if "<" in self._decoded and len(self._decoded) - self._decoded.index("<") < 12 and not self._ending:
            return ""
        self._ending = False
        t = self._decoded
        self._decoded = ""
        t = t.replace("<n>", "\n")
        # "[[训练时间]]" is the model's "training time" placeholder; render it
        # as "2023年" ("the year 2023"), matching process_response.
        t = t.replace("[[训练时间]]", "2023年")
        # Replace half-width punctuation with full-width forms next to CJK text.
        punkts = [
            [",", ","],
            ["!", "!"],
            [":", ":"],
            [";", ";"],
            [r"\?", "?"],
        ]
        for item in punkts:
            t = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], t)
            t = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], t)
        # Upstream process_response expands blank tokens with a fixed bound:
        #     for i in range(max_len, 1, -1):
        #         t = t.replace(f"<|blank_{i}|>", " " * i)
        # Here we match any "<|blank_n|>" without needing max_len.
        for blank_token in re.findall(r"<\|blank_\d+\|>", t):
            t = t.replace(blank_token, " " *
                          int(re.search(r"\d+", blank_token)[0]))
        return t
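# Buffering behaviour (a sketch, assuming `decoder` currently holds the text
# "<|blank_" in its internal buffer):
#
#   decoder.get()   # -> "" : "<|blank_" may still become "<|blank_8|>"
#   decoder.end()
#   decoder.get()   # -> "<|blank_" : flushed verbatim once the stream ends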
class ChatGLMStreamDecoderTest(unittest.TestCase):
    def test_ChatGLM_StreamDecoder(self):
        from transformers import AutoTokenizer, AutoModel
        test_strings = [
            "你好👋",  # multi-byte encoding
            "Hello this is ChatGLM!",  # normal text
            "你好👋 This is ChatGLM!",  # multi-byte encoding with tail
            "!?.,!?。,",  # punctuation
            "A\nB",  # "<n>" -> "\n"
            "[[训练时间]]",  # training time token
            "[[训练时间]123",  # broken training time token
            # blank token (8 consecutive spaces, restored here after being
            # collapsed by the page). Note: it's hard to match the results of
            # strip(), so add a leading and a trailing "1".
            "1        1",
            "<|blank_8|123",  # broken blank token
        ]
        tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
        model = model.eval()
        encoded_ids = [tokenizer(x)['input_ids'] for x in test_strings]
        stream_decoder = ChatGLMStreamDecoder(
            tokenizer.sp_tokenizer.text_tokenizer.sp)
        # Reference output: decode everything at once, then post-process.
        expected_outputs = [model.process_response(
            tokenizer.decode(x)) for x in encoded_ids]
        # Decode token by token.
        decoded_strings_stream_token_by_token = [None for _ in test_strings]
        for i in range(len(test_strings)):
            res = []
            for t in encoded_ids[i]:
                stream_decoder.put([t])
                res.append(stream_decoder.get())
            stream_decoder.end()
            res.append(stream_decoder.get())
            res = "".join(res)
            decoded_strings_stream_token_by_token[i] = res
        # Decode all at once.
        decoded_strings_stream = [None for _ in test_strings]
        for i in range(len(test_strings)):
            stream_decoder.put(encoded_ids[i])
            stream_decoder.end()
            decoded_strings_stream[i] = stream_decoder.get()
        for i in range(len(test_strings)):
            print(
                f"Stream decoder test {i}: expected: '{expected_outputs[i]}', "
                f"token by token: '{decoded_strings_stream_token_by_token[i]}', "
                f"all at once: '{decoded_strings_stream[i]}'")
            self.assertEqual(
                expected_outputs[i], decoded_strings_stream_token_by_token[i])
            self.assertEqual(expected_outputs[i], decoded_strings_stream[i])


if __name__ == "__main__":
    unittest.main()
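To run the test, save the second file as stream_utils.py (the name the demo imports) and run it directly, since the file calls unittest.main(). Note that it downloads the THUDM/chatglm-6b weights on first use and needs a CUDA GPU for the .half().cuda() call:

python stream_utils.py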