mirror of https://github.com/jumpserver/jumpserver
				
				
				
			perf: connection token 分api权限
							parent
							
								
									2493647e5c
								
							
						
					
					
						commit
						d23953932f
					
				| 
						 | 
				
			
			@ -7,7 +7,6 @@ import os
 | 
			
		|||
import base64
 | 
			
		||||
import ctypes
 | 
			
		||||
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.shortcuts import get_object_or_404
 | 
			
		||||
from django.http import HttpResponse
 | 
			
		||||
| 
						 | 
				
			
			@ -33,11 +32,11 @@ from perms.utils.asset.permission import get_asset_actions
 | 
			
		|||
from common.const.http import PATCH
 | 
			
		||||
from terminal.models import EndpointRule
 | 
			
		||||
from ..serializers import (
 | 
			
		||||
    ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
 | 
			
		||||
    ConnectionTokenSerializer, ConnectionTokenSecretSerializer, SuperConnectionTokenSerializer
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
__all__ = ['UserConnectionTokenViewSet', 'TokenCacheMixin']
 | 
			
		||||
__all__ = ['UserConnectionTokenViewSet', 'UserSuperConnectionTokenViewSet', 'TokenCacheMixin']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClientProtocolMixin:
 | 
			
		||||
| 
						 | 
				
			
			@ -70,8 +69,7 @@ class ClientProtocolMixin:
 | 
			
		|||
        system_user = serializer.validated_data['system_user']
 | 
			
		||||
 | 
			
		||||
        user = serializer.validated_data.get('user')
 | 
			
		||||
        if not user or not self.request.user.is_superuser:
 | 
			
		||||
            user = self.request.user
 | 
			
		||||
        user = user if user else self.request.user
 | 
			
		||||
        return asset, application, system_user, user
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +103,7 @@ class ClientProtocolMixin:
 | 
			
		|||
            'bookmarktype:i': '3',
 | 
			
		||||
            'use redirection server name:i': '0',
 | 
			
		||||
            'smart sizing:i': '1',
 | 
			
		||||
            #'drivestoredirect:s': '*',
 | 
			
		||||
            # 'drivestoredirect:s': '*',
 | 
			
		||||
            # 'domain:s': ''
 | 
			
		||||
            # 'alternate shell:s:': '||MySQLWorkbench',
 | 
			
		||||
            # 'remoteapplicationname:s': 'Firefox',
 | 
			
		||||
| 
						 | 
				
			
			@ -206,21 +204,6 @@ class ClientProtocolMixin:
 | 
			
		|||
        rst = rst.decode('ascii')
 | 
			
		||||
        return rst
 | 
			
		||||
 | 
			
		||||
    @action(methods=['POST', 'GET'], detail=False, url_path='rdp/file')
 | 
			
		||||
    def get_rdp_file(self, request, *args, **kwargs):
 | 
			
		||||
        if self.request.method == 'GET':
 | 
			
		||||
            data = self.request.query_params
 | 
			
		||||
        else:
 | 
			
		||||
            data = self.request.data
 | 
			
		||||
        serializer = self.get_serializer(data=data)
 | 
			
		||||
        serializer.is_valid(raise_exception=True)
 | 
			
		||||
        name, data = self.get_rdp_file_content(serializer)
 | 
			
		||||
        response = HttpResponse(data, content_type='application/octet-stream')
 | 
			
		||||
        filename = "{}-{}-jumpserver.rdp".format(self.request.user.username, name)
 | 
			
		||||
        filename = urllib.parse.quote(filename)
 | 
			
		||||
        response['Content-Disposition'] = 'attachment; filename*=UTF-8\'\'%s' % filename
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def get_valid_serializer(self):
 | 
			
		||||
        if self.request.method == 'GET':
 | 
			
		||||
            data = self.request.query_params
 | 
			
		||||
| 
						 | 
				
			
			@ -252,6 +235,21 @@ class ClientProtocolMixin:
 | 
			
		|||
        }
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    @action(methods=['POST', 'GET'], detail=False, url_path='rdp/file')
 | 
			
		||||
    def get_rdp_file(self, request, *args, **kwargs):
 | 
			
		||||
        if self.request.method == 'GET':
 | 
			
		||||
            data = self.request.query_params
 | 
			
		||||
        else:
 | 
			
		||||
            data = self.request.data
 | 
			
		||||
        serializer = self.get_serializer(data=data)
 | 
			
		||||
        serializer.is_valid(raise_exception=True)
 | 
			
		||||
        name, data = self.get_rdp_file_content(serializer)
 | 
			
		||||
        response = HttpResponse(data, content_type='application/octet-stream')
 | 
			
		||||
        filename = "{}-{}-jumpserver.rdp".format(self.request.user.username, name)
 | 
			
		||||
        filename = urllib.parse.quote(filename)
 | 
			
		||||
        response['Content-Disposition'] = 'attachment; filename*=UTF-8\'\'%s' % filename
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    @action(methods=['POST', 'GET'], detail=False, url_path='client-url')
 | 
			
		||||
    def get_client_protocol_url(self, request, *args, **kwargs):
 | 
			
		||||
        serializer = self.get_valid_serializer()
 | 
			
		||||
| 
						 | 
				
			
			@ -370,7 +368,7 @@ class TokenCacheMixin:
 | 
			
		|||
        key = self.get_token_cache_key(token)
 | 
			
		||||
        return cache.ttl(key)
 | 
			
		||||
 | 
			
		||||
    def set_token_to_cache(self, token, value, ttl=5*60):
 | 
			
		||||
    def set_token_to_cache(self, token, value, ttl=5 * 60):
 | 
			
		||||
        key = self.get_token_cache_key(token)
 | 
			
		||||
        cache.set(key, value, timeout=ttl)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -379,7 +377,7 @@ class TokenCacheMixin:
 | 
			
		|||
        value = cache.get(key, None)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def renewal_token(self, token, ttl=5*60):
 | 
			
		||||
    def renewal_token(self, token, ttl=5 * 60):
 | 
			
		||||
        value = self.get_token_from_cache(token)
 | 
			
		||||
        if value:
 | 
			
		||||
            pre_ttl = self.get_token_ttl(token)
 | 
			
		||||
| 
						 | 
				
			
			@ -397,22 +395,10 @@ class TokenCacheMixin:
 | 
			
		|||
        return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserConnectionTokenViewSet(
 | 
			
		||||
class BaseUserConnectionTokenViewSet(
 | 
			
		||||
    RootOrgViewMixin, SerializerMixin, ClientProtocolMixin,
 | 
			
		||||
    SecretDetailMixin, TokenCacheMixin, GenericViewSet
 | 
			
		||||
    TokenCacheMixin, GenericViewSet
 | 
			
		||||
):
 | 
			
		||||
    serializer_classes = {
 | 
			
		||||
        'default': ConnectionTokenSerializer,
 | 
			
		||||
        'get_secret_detail': ConnectionTokenSecretSerializer,
 | 
			
		||||
    }
 | 
			
		||||
    rbac_perms = {
 | 
			
		||||
        'GET': 'authentication.view_connectiontoken',
 | 
			
		||||
        'create': 'authentication.add_connectiontoken',
 | 
			
		||||
        'renewal': 'authentication.add_superconnectiontoken',
 | 
			
		||||
        'get_secret_detail': 'authentication.view_connectiontokensecret',
 | 
			
		||||
        'get_rdp_file': 'authentication.add_connectiontoken',
 | 
			
		||||
        'get_client_protocol_url': 'authentication.add_connectiontoken',
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def check_resource_permission(user, asset, application, system_user):
 | 
			
		||||
| 
						 | 
				
			
			@ -429,22 +415,7 @@ class UserConnectionTokenViewSet(
 | 
			
		|||
            raise PermissionDenied(error)
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    @action(methods=[PATCH], detail=False)
 | 
			
		||||
    def renewal(self, request, *args, **kwargs):
 | 
			
		||||
        """ 续期 Token """
 | 
			
		||||
        perm_required = 'authentication.add_superconnectiontoken'
 | 
			
		||||
        if not request.user.has_perm(perm_required):
 | 
			
		||||
            raise PermissionDenied('No permissions for authentication.add_superconnectiontoken')
 | 
			
		||||
        token = request.data.get('token', '')
 | 
			
		||||
        data = self.renewal_token(token)
 | 
			
		||||
        status_code = 200 if data.get('ok') else 404
 | 
			
		||||
        return Response(data=data, status=status_code)
 | 
			
		||||
 | 
			
		||||
    def create_token(self, user, asset, application, system_user, ttl=5*60):
 | 
			
		||||
        # 再次强调一下权限
 | 
			
		||||
        perm_required = 'authentication.add_superconnectiontoken'
 | 
			
		||||
        if user != self.request.user and not self.request.user.has_perm(perm_required):
 | 
			
		||||
            raise PermissionDenied('Only can create user token')
 | 
			
		||||
    def create_token(self, user, asset, application, system_user, ttl=5 * 60):
 | 
			
		||||
        self.check_resource_permission(user, asset, application, system_user)
 | 
			
		||||
        token = random_string(36)
 | 
			
		||||
        secret = random_string(16)
 | 
			
		||||
| 
						 | 
				
			
			@ -489,6 +460,20 @@ class UserConnectionTokenViewSet(
 | 
			
		|||
        }
 | 
			
		||||
        return Response(data, status=201)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserConnectionTokenViewSet(BaseUserConnectionTokenViewSet, SecretDetailMixin):
 | 
			
		||||
    serializer_classes = {
 | 
			
		||||
        'default': ConnectionTokenSerializer,
 | 
			
		||||
        'get_secret_detail': ConnectionTokenSecretSerializer,
 | 
			
		||||
    }
 | 
			
		||||
    rbac_perms = {
 | 
			
		||||
        'GET': 'authentication.view_connectiontoken',
 | 
			
		||||
        'create': 'authentication.add_connectiontoken',
 | 
			
		||||
        'get_secret_detail': 'authentication.view_connectiontokensecret',
 | 
			
		||||
        'get_rdp_file': 'authentication.add_connectiontoken',
 | 
			
		||||
        'get_client_protocol_url': 'authentication.add_connectiontoken',
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def valid_token(self, token):
 | 
			
		||||
        from users.models import User
 | 
			
		||||
        from assets.models import SystemUser, Asset
 | 
			
		||||
| 
						 | 
				
			
			@ -526,3 +511,23 @@ class UserConnectionTokenViewSet(
 | 
			
		|||
        if not value:
 | 
			
		||||
            return Response('', status=404)
 | 
			
		||||
        return Response(value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserSuperConnectionTokenViewSet(
 | 
			
		||||
    BaseUserConnectionTokenViewSet, TokenCacheMixin, GenericViewSet
 | 
			
		||||
):
 | 
			
		||||
    serializer_classes = {
 | 
			
		||||
        'default': SuperConnectionTokenSerializer,
 | 
			
		||||
    }
 | 
			
		||||
    rbac_perms = {
 | 
			
		||||
        'create': 'authentication.add_superconnectiontoken',
 | 
			
		||||
        'renewal': 'authentication.add_superconnectiontoken'
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @action(methods=[PATCH], detail=False)
 | 
			
		||||
    def renewal(self, request, *args, **kwargs):
 | 
			
		||||
        """ 续期 Token """
 | 
			
		||||
        token = request.data.get('token', '')
 | 
			
		||||
        data = self.renewal_token(token)
 | 
			
		||||
        status_code = 200 if data.get('ok') else 404
 | 
			
		||||
        return Response(data=data, status=status_code)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,24 +13,16 @@ __all__ = [
 | 
			
		|||
    'ConnectionTokenUserSerializer', 'ConnectionTokenFilterRuleSerializer',
 | 
			
		||||
    'ConnectionTokenAssetSerializer', 'ConnectionTokenSystemUserSerializer',
 | 
			
		||||
    'ConnectionTokenDomainSerializer', 'ConnectionTokenRemoteAppSerializer',
 | 
			
		||||
    'ConnectionTokenGatewaySerializer', 'ConnectionTokenSecretSerializer'
 | 
			
		||||
    'ConnectionTokenGatewaySerializer', 'ConnectionTokenSecretSerializer',
 | 
			
		||||
    'SuperConnectionTokenSerializer'
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConnectionTokenSerializer(serializers.Serializer):
 | 
			
		||||
    user = serializers.CharField(max_length=128, required=False, allow_blank=True)
 | 
			
		||||
    system_user = serializers.CharField(max_length=128, required=True)
 | 
			
		||||
    asset = serializers.CharField(max_length=128, required=False)
 | 
			
		||||
    application = serializers.CharField(max_length=128, required=False)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def validate_user(user_id):
 | 
			
		||||
        from users.models import User
 | 
			
		||||
        user = User.objects.filter(id=user_id).first()
 | 
			
		||||
        if user is None:
 | 
			
		||||
            raise serializers.ValidationError('user id not exist')
 | 
			
		||||
        return user
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def validate_system_user(system_user_id):
 | 
			
		||||
        from assets.models import SystemUser
 | 
			
		||||
| 
						 | 
				
			
			@ -65,6 +57,18 @@ class ConnectionTokenSerializer(serializers.Serializer):
 | 
			
		|||
        return super().validate(attrs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SuperConnectionTokenSerializer(ConnectionTokenSerializer):
 | 
			
		||||
    user = serializers.CharField(max_length=128, required=False, allow_blank=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def validate_user(user_id):
 | 
			
		||||
        from users.models import User
 | 
			
		||||
        user = User.objects.filter(id=user_id).first()
 | 
			
		||||
        if user is None:
 | 
			
		||||
            raise serializers.ValidationError('user id not exist')
 | 
			
		||||
        return user
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConnectionTokenUserSerializer(serializers.ModelSerializer):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = User
 | 
			
		||||
| 
						 | 
				
			
			@ -114,7 +118,6 @@ class ConnectionTokenDomainSerializer(serializers.ModelSerializer):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ConnectionTokenFilterRuleSerializer(serializers.ModelSerializer):
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = CommandFilterRule
 | 
			
		||||
        fields = [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@ router.register('access-keys', api.AccessKeyViewSet, 'access-key')
 | 
			
		|||
router.register('sso', api.SSOViewSet, 'sso')
 | 
			
		||||
router.register('temp-tokens', api.TempTokenViewSet, 'temp-token')
 | 
			
		||||
router.register('connection-token', api.UserConnectionTokenViewSet, 'connection-token')
 | 
			
		||||
router.register('super-connection-token', api.UserSuperConnectionTokenViewSet, 'super-connection-token')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue