apply Black formatting

pull/216/head
Ji Zhang 2023-03-24 21:27:57 -07:00
parent 2c424264b9
commit cd28454693
1 changed file with 22 additions and 12 deletions


@@ -5,10 +5,17 @@ from transformers import AutoTokenizer, AutoModel
 """ChatGLM_G is a wrapper around the ChatGLM model to fit LangChain framework. May not be an optimal implementation"""
 
 
 class ChatGLM_G(LLM):
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
+    tokenizer = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b-int4", trust_remote_code=True
+    )
+    model = (
+        AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+        .half()
+        .cuda()
+    )
     history = []
 
     @property
@@ -16,20 +23,23 @@ class ChatGLM_G(LLM):
         return "ChatGLM_G"
 
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        response, updated_history = self.model.chat(self.tokenizer, prompt, history=self.history)
-        print("ChatGLM: prompt: ", prompt)
-        print("ChatGLM: response: ", response)
-        print("history: ", self.history)
+        response, updated_history = self.model.chat(
+            self.tokenizer, prompt, history=self.history, max_length=10000
+        )
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
         self.history = updated_history
         return response
 
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        response, updated_history = self.model.chat(self.tokenizer, prompt, history=self.history)
-        print("ChatGLM: prompt: ", prompt)
-        print("ChatGLM: response: ", response)
-        print("history: ", self.history)
+        response, updated_history = self.model.chat(
+            self.tokenizer, prompt, history=self.history, max_length=10000
+        )
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
         self.history = updated_history
         return response
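
For context, a minimal usage sketch of the patched wrapper (not part of the commit). It assumes the file's existing imports (the LangChain LLM base class, enforce_stop_tokens, Optional, List) and a CUDA-capable GPU, since the model is loaded with .half().cuda() at class definition time:

    # Minimal sketch, assuming the imports already present in this file
    # (LLM, enforce_stop_tokens, Optional, List) and a working CUDA device.
    llm = ChatGLM_G()

    # First turn: self.history starts empty and is updated after the call.
    answer = llm("Briefly, what is LangChain?")
    print(answer)

    # Follow-up turn: the accumulated history is passed back into model.chat,
    # and the new max_length=10000 bounds the combined prompt plus history.
    follow_up = llm("Summarize that in one sentence.", stop=["\n"])
    print(follow_up)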