# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from enum import Enum, auto
from typing import List


class SeparatorStyle(Enum):
    """Ways of joining conversation turns when rendering a prompt."""

    ADD_EOS_TOKEN = auto()


@dataclasses.dataclass
class Conversation:
    """A conversation history that can be rendered into a single prompt."""

    # Text placed at the very beginning of the rendered prompt.
    system: str
    # Speaker names; stored and copied but not read by any method here.
    roles: List[str]
    # Turns as [role, message] pairs; a falsy message marks a turn that
    # has not been filled in yet.
    messages: List[List[str]]
    # Number of leading messages skipped by ``to_gradio_chatbot``.
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
    # Appended after every non-empty message in ``get_prompt``.
    sep: str = ""

    # Not read by any method of this class.
    skip_next: bool = False

    def get_prompt(self) -> str:
        """Render the system text and all messages as one prompt string.

        Each turn is written as ``"<role>: <message><sep>"``. A turn whose
        message is falsy (``None`` or ``""``) is written as ``"<role>: "``
        with no separator.

        Raises:
            ValueError: If ``sep_style`` is not ``ADD_EOS_TOKEN``.
        """
        if self.sep_style != SeparatorStyle.ADD_EOS_TOKEN:
            raise ValueError(f"Invalid style: {self.sep_style}")

        # Collect the pieces and join once instead of growing a string.
        parts = [self.system]
        for role, message in self.messages:
            if message:
                parts += [role, ": ", message, self.sep]
            else:
                parts += [role, ": "]
        return "".join(parts)

    def append_message(self, role, message) -> None:
        """Append a ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up the messages after ``offset`` as two-element lists.

        Even-indexed messages (counted from ``offset``) open a new
        ``[message, None]`` pair; odd-indexed messages fill the second
        slot of the most recent pair.
        """
        ret = []
        for i, (_role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self) -> "Conversation":
        """Return a new conversation with an independent message list.

        Every ``[role, message]`` pair is rebuilt, so editing the copy does
        not touch this instance. ``roles`` is shared by reference, and
        ``skip_next`` is not carried over (the copy gets the default).
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
        )

    def dict(self):
        """Return the main fields as a plain dictionary.

        ``sep_style`` and ``skip_next`` are not included. The values are
        the live attribute objects, not copies.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
        }


# Template conversation with a system text and no messages. ``messages`` is
# a tuple here, so ``append_message`` only works on a ``copy()`` of it.
# NOTE(review): the system text is assumed to be two adjacent string
# literals joined without a line break - confirm against the intended text.
conv = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.ADD_EOS_TOKEN,
    sep="",
)

default_conversation = conv