mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			239 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			239 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
# -*- coding: utf-8 -*-
 | 
						|
#
 | 
						|
 | 
						|
import threading
 | 
						|
from rest_framework import generics
 | 
						|
from rest_framework.views import Response, APIView
 | 
						|
from orgs.models import Organization
 | 
						|
from django.utils.translation import ugettext_lazy as _
 | 
						|
from django.conf import settings
 | 
						|
 | 
						|
from ..models import Setting
 | 
						|
from ..utils import (
 | 
						|
    LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
 | 
						|
    LDAP_USE_CACHE_FLAGS, LDAPTestUtil
 | 
						|
)
 | 
						|
from ..tasks import sync_ldap_user
 | 
						|
from common.utils import get_logger, is_uuid
 | 
						|
from ..serializers import (
 | 
						|
    LDAPTestConfigSerializer, LDAPUserSerializer,
 | 
						|
    LDAPTestLoginSerializer
 | 
						|
)
 | 
						|
from orgs.utils import current_org
 | 
						|
from users.models import User
 | 
						|
 | 
						|
logger = get_logger(__file__)
 | 
						|
 | 
						|
 | 
						|
class LDAPTestingConfigAPI(APIView):
 | 
						|
    serializer_class = LDAPTestConfigSerializer
 | 
						|
    perm_model = Setting
 | 
						|
    rbac_perms = {
 | 
						|
        'POST': 'settings.change_auth'
 | 
						|
    }
 | 
						|
 | 
						|
    def post(self, request):
 | 
						|
        serializer = self.serializer_class(data=request.data)
 | 
						|
        if not serializer.is_valid():
 | 
						|
            return Response({"error": str(serializer.errors)}, status=400)
 | 
						|
        config = self.get_ldap_config(serializer)
 | 
						|
        ok, msg = LDAPTestUtil(config).test_config()
 | 
						|
        status = 200 if ok else 400
 | 
						|
        return Response(msg, status=status)
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def get_ldap_config(serializer):
 | 
						|
        server_uri = serializer.validated_data["AUTH_LDAP_SERVER_URI"]
 | 
						|
        bind_dn = serializer.validated_data["AUTH_LDAP_BIND_DN"]
 | 
						|
        password = serializer.validated_data["AUTH_LDAP_BIND_PASSWORD"]
 | 
						|
        use_ssl = serializer.validated_data.get("AUTH_LDAP_START_TLS", False)
 | 
						|
        search_ou = serializer.validated_data["AUTH_LDAP_SEARCH_OU"]
 | 
						|
        search_filter = serializer.validated_data["AUTH_LDAP_SEARCH_FILTER"]
 | 
						|
        attr_map = serializer.validated_data["AUTH_LDAP_USER_ATTR_MAP"]
 | 
						|
        auth_ldap = serializer.validated_data.get('AUTH_LDAP', False)
 | 
						|
 | 
						|
        if not password:
 | 
						|
            password = settings.AUTH_LDAP_BIND_PASSWORD
 | 
						|
 | 
						|
        config = {
 | 
						|
            'server_uri': server_uri,
 | 
						|
            'bind_dn': bind_dn,
 | 
						|
            'password': password,
 | 
						|
            'use_ssl': use_ssl,
 | 
						|
            'search_ou': search_ou,
 | 
						|
            'search_filter': search_filter,
 | 
						|
            'attr_map': attr_map,
 | 
						|
            'auth_ldap': auth_ldap
 | 
						|
        }
 | 
						|
        return config
 | 
						|
 | 
						|
 | 
						|
class LDAPTestingLoginAPI(APIView):
 | 
						|
    serializer_class = LDAPTestLoginSerializer
 | 
						|
    perm_model = Setting
 | 
						|
    rbac_perms = {
 | 
						|
        'POST': 'settings.change_auth'
 | 
						|
    }
 | 
						|
 | 
						|
    def post(self, request):
 | 
						|
        serializer = self.serializer_class(data=request.data)
 | 
						|
        if not serializer.is_valid():
 | 
						|
            return Response({"error": str(serializer.errors)}, status=400)
 | 
						|
        username = serializer.validated_data['username']
 | 
						|
        password = serializer.validated_data['password']
 | 
						|
        ok, msg = LDAPTestUtil().test_login(username, password)
 | 
						|
        status = 200 if ok else 400
 | 
						|
        return Response(msg, status=status)
 | 
						|
 | 
						|
 | 
						|
class LDAPUserListApi(generics.ListAPIView):
 | 
						|
    serializer_class = LDAPUserSerializer
 | 
						|
    perm_model = Setting
 | 
						|
    rbac_perms = {
 | 
						|
        'list': 'settings.change_auth'
 | 
						|
    }
 | 
						|
 | 
						|
    def get_queryset_from_cache(self):
 | 
						|
        search_value = self.request.query_params.get('search')
 | 
						|
        users = LDAPCacheUtil().search(search_value=search_value)
 | 
						|
        return users
 | 
						|
 | 
						|
    def get_queryset_from_server(self):
 | 
						|
        search_value = self.request.query_params.get('search')
 | 
						|
        users = LDAPServerUtil().search(search_value=search_value)
 | 
						|
        return users
 | 
						|
 | 
						|
    def get_queryset(self):
 | 
						|
        if hasattr(self, 'swagger_fake_view'):
 | 
						|
            return User.objects.none()
 | 
						|
        cache_police = self.request.query_params.get('cache_police', True)
 | 
						|
        if cache_police in LDAP_USE_CACHE_FLAGS:
 | 
						|
            users = self.get_queryset_from_cache()
 | 
						|
        else:
 | 
						|
            users = self.get_queryset_from_server()
 | 
						|
        return users
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def processing_queryset(queryset):
 | 
						|
        db_username_list = User.objects.all().values_list('username', flat=True)
 | 
						|
        for q in queryset:
 | 
						|
            q['id'] = q['username']
 | 
						|
            q['existing'] = q['username'] in db_username_list
 | 
						|
        return queryset
 | 
						|
 | 
						|
    def sort_queryset(self, queryset):
 | 
						|
        order_by = self.request.query_params.get('order')
 | 
						|
        if not order_by:
 | 
						|
            order_by = 'existing'
 | 
						|
        if order_by.startswith('-'):
 | 
						|
            order_by = order_by.lstrip('-')
 | 
						|
            reverse = True
 | 
						|
        else:
 | 
						|
            reverse = False
 | 
						|
        queryset = sorted(queryset, key=lambda x: x[order_by], reverse=reverse)
 | 
						|
        return queryset
 | 
						|
 | 
						|
    def filter_queryset(self, queryset):
 | 
						|
        if queryset is None:
 | 
						|
            return queryset
 | 
						|
        queryset = self.processing_queryset(queryset)
 | 
						|
        queryset = self.sort_queryset(queryset)
 | 
						|
        return queryset
 | 
						|
 | 
						|
    def list(self, request, *args, **kwargs):
 | 
						|
        cache_police = self.request.query_params.get('cache_police', True)
 | 
						|
        # 不是用缓存
 | 
						|
        if cache_police not in LDAP_USE_CACHE_FLAGS:
 | 
						|
            return super().list(request, *args, **kwargs)
 | 
						|
 | 
						|
        try:
 | 
						|
            queryset = self.get_queryset()
 | 
						|
        except Exception as e:
 | 
						|
            data = {'error': str(e)}
 | 
						|
            return Response(data=data, status=400)
 | 
						|
 | 
						|
        # 缓存有数据
 | 
						|
        if queryset is not None:
 | 
						|
            return super().list(request, *args, **kwargs)
 | 
						|
 | 
						|
        sync_util = LDAPSyncUtil()
 | 
						|
        # 还没有同步任务
 | 
						|
        if sync_util.task_no_start:
 | 
						|
            # 任务外部设置 task running 状态
 | 
						|
            sync_util.set_task_status(sync_util.TASK_STATUS_IS_RUNNING)
 | 
						|
            t = threading.Thread(target=sync_ldap_user)
 | 
						|
            t.start()
 | 
						|
            data = {'msg': _('Synchronization start, please wait.')}
 | 
						|
            return Response(data=data, status=409)
 | 
						|
        # 同步任务正在执行
 | 
						|
        if sync_util.task_is_running:
 | 
						|
            data = {'msg': _('Synchronization is running, please wait.')}
 | 
						|
            return Response(data=data, status=409)
 | 
						|
        # 同步任务执行结束
 | 
						|
        if sync_util.task_is_over:
 | 
						|
            msg = sync_util.get_task_error_msg()
 | 
						|
            data = {'error': _('Synchronization error: {}'.format(msg))}
 | 
						|
            return Response(data=data, status=400)
 | 
						|
 | 
						|
        return super().list(request, *args, **kwargs)
 | 
						|
 | 
						|
 | 
						|
class LDAPUserImportAPI(APIView):
 | 
						|
    perm_model = Setting
 | 
						|
    rbac_perms = {
 | 
						|
        'POST': 'settings.change_auth'
 | 
						|
    }
 | 
						|
 | 
						|
    def get_org(self):
 | 
						|
        org_id = self.request.data.get('org_id')
 | 
						|
        if is_uuid(org_id):
 | 
						|
            org = Organization.objects.get(id=org_id)
 | 
						|
        else:
 | 
						|
            org = current_org
 | 
						|
        return org
 | 
						|
 | 
						|
    def get_ldap_users(self):
 | 
						|
        username_list = self.request.data.get('username_list', [])
 | 
						|
        cache_police = self.request.query_params.get('cache_police', True)
 | 
						|
        if '*' in username_list:
 | 
						|
            users = LDAPServerUtil().search()
 | 
						|
        elif cache_police in LDAP_USE_CACHE_FLAGS:
 | 
						|
            users = LDAPCacheUtil().search(search_users=username_list)
 | 
						|
        else:
 | 
						|
            users = LDAPServerUtil().search(search_users=username_list)
 | 
						|
        return users
 | 
						|
 | 
						|
    def post(self, request):
 | 
						|
        try:
 | 
						|
            users = self.get_ldap_users()
 | 
						|
        except Exception as e:
 | 
						|
            return Response({'error': str(e)}, status=400)
 | 
						|
 | 
						|
        if users is None:
 | 
						|
            return Response({'msg': _('Get ldap users is None')}, status=400)
 | 
						|
 | 
						|
        org = self.get_org()
 | 
						|
        errors = LDAPImportUtil().perform_import(users, org)
 | 
						|
        if errors:
 | 
						|
            return Response({'errors': errors}, status=400)
 | 
						|
 | 
						|
        count = users if users is None else len(users)
 | 
						|
        return Response({
 | 
						|
            'msg': _('Imported {} users successfully (Organization: {})').format(count, org)
 | 
						|
        })
 | 
						|
 | 
						|
 | 
						|
class LDAPCacheRefreshAPI(generics.RetrieveAPIView):
 | 
						|
    perm_model = Setting
 | 
						|
    rbac_perms = {
 | 
						|
        'retrieve': 'settings.change_auth'
 | 
						|
    }
 | 
						|
 | 
						|
    def retrieve(self, request, *args, **kwargs):
 | 
						|
        try:
 | 
						|
            LDAPSyncUtil().clear_cache()
 | 
						|
        except Exception as e:
 | 
						|
            logger.error(str(e))
 | 
						|
            return Response(data={'msg': str(e)}, status=400)
 | 
						|
        return Response(data={'msg': 'success'})
 |