You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

119 lines
3.9 KiB

import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers
from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt
class ChatAITestingAPI(GenericAPIView):
serializer_class = serializers.ChatAISettingSerializer
rbac_perms = {
'POST': 'settings.change_chatai'
}
def get_config(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = self.serializer_class().data
data.update(serializer.validated_data)
for k, v in data.items():
if v:
continue
# 页面没有传递值, 从 settings 中获取
data[k] = getattr(settings, k, None)
return data
def post(self, request):
config = self.get_config(request)
chat_ai_enabled = config['CHAT_AI_ENABLED']
if not chat_ai_enabled:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'msg': _('Chat AI is not enabled')}
)
proxy = config['GPT_PROXY']
model = config['GPT_MODEL']
kwargs = {
'base_url': config['GPT_BASE_URL'] or None,
'api_key': config['GPT_API_KEY'],
}
try:
if proxy:
kwargs['http_client'] = httpx.Client(
proxies=proxy,
transport=httpx.HTTPTransport(local_address='0.0.0.0')
)
client = openai.OpenAI(**kwargs)
ok = False
error = ''
client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Say this is a test",
}
],
model=model,
)
ok = True
except openai.APIConnectionError as e:
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
except openai.APIStatusError as e:
error = str(e.message)
except Exception as e:
ok, error = False, str(e)
if ok:
_status, msg = status.HTTP_200_OK, _('Test success')
else:
_status, msg = status.HTTP_400_BAD_REQUEST, error
return Response(status=_status, data={'msg': msg})
class ChatPromptViewSet(JMSModelViewSet):
serializer_classes = {
'default': serializers.ChatPromptSerializer,
}
permission_classes = [IsValidUser]
queryset = ChatPrompt.objects.all()
http_method_names = ['get', 'options']
filterset_fields = ['name']
search_fields = filterset_fields
def get_permissions(self):
if self.action in ['create', 'update', 'partial_update', 'destroy']:
self.permission_classes = [OnlySuperUser]
return super().get_permissions()
def filter_default_prompts(self):
lang = self.request.LANGUAGE_CODE
default_prompts = DefaultChatPrompt.get_prompts(lang)
search_query = self.request.query_params.get('search')
search_query = search_query or self.request.query_params.get('name')
if not search_query:
return default_prompts
search_query = search_query.lower()
filtered_prompts = [
prompt for prompt in default_prompts
if search_query in prompt['name'].lower()
]
return filtered_prompts
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
default_prompts = self.filter_default_prompts()
return list(queryset) + default_prompts