|
|
@ -1,3 +1,4 @@
|
|
|
|
|
|
|
|
import re
|
|
|
|
from threading import Lock
|
|
|
|
from threading import Lock
|
|
|
|
from typing import Any, Callable, Generator, List, Optional
|
|
|
|
from typing import Any, Callable, Generator, List, Optional
|
|
|
|
|
|
|
|
|
|
|
@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
|
|
|
|
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
|
|
|
|
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatPromptProcessor:
|
|
|
|
class ChatPromptProcessor:
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
|
|
|
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
|
|
@ -164,6 +168,10 @@ class ChatPromptProcessor:
|
|
|
|
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
|
|
|
|
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
|
|
|
|
return prompt
|
|
|
|
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess_output(self, output: str) -> str:
|
|
|
|
|
|
|
|
output = STOP_PAT.sub('', output)
|
|
|
|
|
|
|
|
return output.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LockedIterator:
|
|
|
|
class LockedIterator:
|
|
|
|
|
|
|
|
|
|
|
|