mirror of https://github.com/InternLM/InternLM
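# Streamlit demo: chat with an InternLM2.5 model about an uploaded document.
# The app talks to an OpenAI-compatible endpoint passed via --base_url,
# converts uploaded PDFs to text with magic_doc's DocConverter (txt/md files
# are decoded directly), and prepends the file content to the first user
# prompt after the upload.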
import argparse
import logging
from dataclasses import dataclass

import streamlit as st
from magic_doc.docconv import DocConverter
from openai import OpenAI

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_tokens: int = 1024
    top_p: float = 1.0
    temperature: float = 0.1
    repetition_penalty: float = 1.005

def generate(
    client,
    messages,
    generation_config,
):
    stream = client.chat.completions.create(
        model=st.session_state['model_name'],
        messages=messages,
        stream=True,
        temperature=generation_config.temperature,
        top_p=generation_config.top_p,
        max_tokens=generation_config.max_tokens,
        frequency_penalty=generation_config.repetition_penalty,
    )
    return stream

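# Note: the OpenAI chat.completions API has no repetition_penalty parameter,
# so generate() passes GenerationConfig.repetition_penalty through as
# frequency_penalty; the two penalties are related but not equivalent, and how
# an OpenAI-compatible backend applies the value depends on the server.
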
def prepare_generation_config():
    with st.sidebar:
        max_tokens = st.number_input('Max Tokens',
                                     min_value=100,
                                     max_value=4096,
                                     value=1024)
        top_p = st.number_input('Top P', 0.0, 1.0, 1.0, step=0.01)
        temperature = st.number_input('Temperature', 0.0, 1.0, 0.05, step=0.01)
        repetition_penalty = st.number_input('Repetition Penalty',
                                             0.8,
                                             1.2,
                                             1.02,
                                             step=0.001,
                                             format='%0.3f')
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_tokens=max_tokens,
                                         top_p=top_p,
                                         temperature=temperature,
                                         repetition_penalty=repetition_penalty)

    return generation_config

def on_btn_click():
    del st.session_state.messages
    st.session_state.file_content_found = False
    st.session_state.file_content_used = False

user_avator = 'assets/user.png'
robot_avator = 'assets/robot.png'

st.title('InternLM2.5 File Chat 📝')

def main(base_url):
    # Initialize the client for the model
    client = OpenAI(base_url=base_url, timeout=12000)

    # Get the model ID
    model_name = client.models.list().data[0].id
    st.session_state['model_name'] = model_name
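    # Assumes the endpoint serves a single model: the first entry returned by
    # the server's model list is used for every chat request.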
    # Get the generation config
    generation_config = prepare_generation_config()

    # Initialize session state
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    if 'file_content_found' not in st.session_state:
        st.session_state.file_content_found = False
        st.session_state.file_content_used = False
        st.session_state.file_name = ''
    # Handle file upload
    if not st.session_state.file_content_found:
        uploaded_file = st.file_uploader('Upload an article',
                                         type=('txt', 'md', 'pdf'))
        file_content = ''
        if uploaded_file is not None:
            if uploaded_file.type == 'application/pdf':
                # Persist the upload to disk so DocConverter can read it
                with open('uploaded_file.pdf', 'wb') as f:
                    f.write(uploaded_file.getbuffer())
                converter = DocConverter(s3_config=None)
                file_content, time_cost = converter.convert(
                    'uploaded_file.pdf', conv_timeout=300)
                # Mark that a file has been loaded
                st.session_state.file_content_found = True
                # Store the file content in session state
                st.session_state.file_content = file_content
                # Store the file name in session state
                st.session_state.file_name = uploaded_file.name
            else:
                file_content = uploaded_file.read().decode('utf-8')
                # Mark that a file has been loaded
                st.session_state.file_content_found = True
                # Store the file content in session state
                st.session_state.file_content = file_content
                # Store the file name in session state
                st.session_state.file_name = uploaded_file.name
    if st.session_state.file_content_found:
        st.success(f"File '{st.session_state.file_name}' "
                   'has been successfully uploaded!')

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])
    # Handle user input and response generation
    if prompt := st.chat_input("What's up?"):
        turn = {'role': 'user', 'content': prompt, 'avatar': user_avator}
        if (st.session_state.file_content_found
                and not st.session_state.file_content_used):
            assert st.session_state.file_content is not None
            merged_prompt = f'{st.session_state.file_content}\n\n{prompt}'
            # Set flag to indicate file content has been used
            st.session_state.file_content_used = True
            turn['merged_content'] = merged_prompt
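        # Only the first prompt after an upload carries the file text: it is
        # stored on the turn as 'merged_content', which is what gets sent to
        # the model, while the UI keeps showing the plain prompt. Later turns
        # re-send it as part of the chat history.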
        st.session_state.messages.append(turn)
        with st.chat_message('user', avatar=user_avator):
            st.markdown(prompt)

        with st.chat_message('assistant', avatar=robot_avator):
            messages = [{
                'role':
                m['role'],
                'content':
                m['merged_content'] if 'merged_content' in m else m['content'],
            } for m in st.session_state.messages]
            # Log messages to the terminal
            for m in messages:
                logging.info(
                    f"\n\n*** [{m['role']}] ***\n\n\t{m['content']}\n\n")
            stream = generate(client, messages, generation_config)
            response = st.write_stream(stream)
            st.session_state.messages.append({
                'role': 'assistant',
                'content': response,
                'avatar': robot_avator
            })

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Run Streamlit app with OpenAI client.')
    parser.add_argument('--base_url',
                        type=str,
                        required=True,
                        help='Base URL for the OpenAI client')
    args = parser.parse_args()
    main(args.base_url)
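# Typical launch (model name, port and file name here are illustrative):
#   1. Serve an InternLM2.5 chat model behind an OpenAI-compatible API, e.g.
#      `lmdeploy serve api_server internlm/internlm2_5-7b-chat`.
#   2. Export OPENAI_API_KEY (any placeholder value works for a local server,
#      but the openai client refuses to start without one), then run
#      `streamlit run file_chat.py -- --base_url http://0.0.0.0:23333/v1`,
#      where `--` separates Streamlit's own options from this script's flags.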