[Update] 添加会话加入校验API

pull/3857/head
Bai 2020-04-07 12:26:47 +08:00
parent 29b099efc0
commit e9827c8b25
4 changed files with 54 additions and 8 deletions

View File

@ -1,22 +1,27 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.utils.translation import ugettext as _
from django.shortcuts import get_object_or_404, reverse from django.shortcuts import get_object_or_404, reverse
from django.core.files.storage import default_storage from django.core.files.storage import default_storage
from rest_framework import viewsets from rest_framework import viewsets, views
from rest_framework.response import Response from rest_framework.response import Response
from common.utils import is_uuid, get_logger from common.utils import is_uuid, get_logger, get_object_or_none
from common.mixins.api import AsyncApiMixin from common.mixins.api import AsyncApiMixin
from common.permissions import IsOrgAdminOrAppUser, IsOrgAuditor from common.permissions import IsOrgAdminOrAppUser, IsOrgAuditor, IsAppUser
from common.drf.filters import DatetimeRangeFilter from common.drf.filters import DatetimeRangeFilter
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from orgs.utils import tmp_to_root_org, tmp_to_org
from users.models import User
from ..utils import find_session_replay_local, download_session_replay from ..utils import find_session_replay_local, download_session_replay
from ..hands import SystemUser from ..hands import SystemUser
from ..models import Session from ..models import Session
from .. import serializers from .. import serializers
__all__ = ['SessionViewSet', 'SessionReplayViewSet',] __all__ = [
'SessionViewSet', 'SessionReplayViewSet', 'SessionJoinValidateAPI'
]
logger = get_logger(__name__) logger = get_logger(__name__)
@ -117,3 +122,36 @@ class SessionReplayViewSet(AsyncApiMixin, viewsets.ViewSet):
return Response({"error": url}) return Response({"error": url})
data = self.get_replay_data(session, url) data = self.get_replay_data(session, url)
return Response(data) return Response(data)
class SessionJoinValidateAPI(views.APIView):
permission_classes = (IsAppUser, )
serializer_class = serializers.SessionJoinValidateSerializer
def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data)
if not serializer.is_valid():
msg = str(serializer.errors)
return Response({'ok': False, 'msg': msg}, status=401)
user_id = serializer.validated_data['user_id']
session_id = serializer.validated_data['session_id']
with tmp_to_root_org():
session = get_object_or_none(Session, pk=session_id)
if not session:
msg = _('Session does not exist: {}'.format(session_id))
return Response({'ok': False, 'msg': msg}, status=401)
if not session.can_join():
msg = _('Session is finished or the protocol not supported')
return Response({'ok': False, 'msg': msg}, status=401)
user = get_object_or_none(User, pk=user_id)
if not user:
msg = _('User does not exist: {}'.format(user_id))
return Response({'ok': False, 'msg': msg}, status=401)
with tmp_to_org(session.org):
if not user.admin_or_audit_orgs:
msg = _('User does not have permission')
return Response({'ok': False, 'msg': msg}, status=401)
return Response({'ok': True, 'msg': ''}, status=200)

View File

@ -244,9 +244,11 @@ class Session(OrgModelMixin):
return False return False
def can_join(self): def can_join(self):
if self.protocol in ['ssh', 'telnet', 'mysql']: if self.is_finished:
return True return False
return False if self.protocol not in ['ssh', 'telnet', 'mysql']:
return False
return True
def save_to_storage(self, f): def save_to_storage(self, f):
local_path = self.get_local_path() local_path = self.get_local_path()

View File

@ -6,7 +6,7 @@ from ..models import Session
__all__ = [ __all__ = [
'SessionSerializer', 'SessionDisplaySerializer', 'SessionSerializer', 'SessionDisplaySerializer',
'ReplaySerializer', 'ReplaySerializer', 'SessionJoinValidateSerializer',
] ]
@ -35,3 +35,8 @@ class SessionDisplaySerializer(SessionSerializer):
class ReplaySerializer(serializers.Serializer): class ReplaySerializer(serializers.Serializer):
file = serializers.FileField(allow_empty_file=True) file = serializers.FileField(allow_empty_file=True)
class SessionJoinValidateSerializer(serializers.Serializer):
user_id = serializers.UUIDField()
session_id = serializers.UUIDField()

View File

@ -22,6 +22,7 @@ router.register(r'replay-storages', api.ReplayStorageViewSet, 'replay-storage')
router.register(r'command-storages', api.CommandStorageViewSet, 'command-storage') router.register(r'command-storages', api.CommandStorageViewSet, 'command-storage')
urlpatterns = [ urlpatterns = [
path('sessions/join/validate/', api.SessionJoinValidateAPI.as_view(), name='join-session'),
path('sessions/<uuid:pk>/replay/', path('sessions/<uuid:pk>/replay/',
api.SessionReplayViewSet.as_view({'get': 'retrieve', 'post': 'create'}), api.SessionReplayViewSet.as_view({'get': 'retrieve', 'post': 'create'}),
name='session-replay'), name='session-replay'),