diff --git a/apps/authentication/api/sso.py b/apps/authentication/api/sso.py index 6e48bda41..a4cd6a67d 100644 --- a/apps/authentication/api/sso.py +++ b/apps/authentication/api/sso.py @@ -5,11 +5,13 @@ from django.conf import settings from django.contrib.auth import login from django.http.response import HttpResponseRedirect from rest_framework import serializers +from rest_framework import status from rest_framework.decorators import action from rest_framework.permissions import AllowAny from rest_framework.request import Request from rest_framework.response import Response +from authentication.errors import ACLError from common.api import JMSGenericViewSet from common.const.http import POST, GET from common.permissions import OnlySuperUser @@ -17,7 +19,10 @@ from common.serializers import EmptySerializer from common.utils import reverse, safe_next_url from common.utils.timezone import utc_now from users.models import User -from ..errors import SSOAuthClosed +from users.utils import LoginBlockUtil, LoginIpBlockUtil +from ..errors import ( + SSOAuthClosed, AuthFailedError, LoginConfirmBaseError, SSOAuthKeyTTLError +) from ..filters import AuthKeyQueryDeclaration from ..mixins import AuthMixin from ..models import SSOToken @@ -63,31 +68,58 @@ class SSOViewSet(AuthMixin, JMSGenericViewSet): 此接口违反了 `Restful` 的规范 `GET` 应该是安全的方法,但此接口是不安全的 """ + status_code = status.HTTP_400_BAD_REQUEST request.META['HTTP_X_JMS_LOGIN_TYPE'] = 'W' authkey = request.query_params.get(AUTH_KEY) next_url = request.query_params.get(NEXT_URL) if not next_url or not next_url.startswith('/'): next_url = reverse('index') - if not authkey: - raise serializers.ValidationError("authkey is required") - try: + if not authkey: + raise serializers.ValidationError("authkey is required") + authkey = UUID(authkey) token = SSOToken.objects.get(authkey=authkey, expired=False) - # 先过期,只能访问这一次 + except (ValueError, SSOToken.DoesNotExist, serializers.ValidationError) as e: + error_msg = str(e) + self.send_auth_signal(success=False, reason=error_msg) + return Response({'error': error_msg}, status=status_code) + + error_msg = None + user = token.user + username = user.username + ip = self.get_request_ip() + + try: + if (utc_now().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL: + raise SSOAuthKeyTTLError() + + self._check_is_block(username, True) + self._check_only_allow_exists_user_auth(username) + self._check_login_acl(user, ip) + self.check_user_login_confirm_if_need(user) + + self.request.session['auth_backend'] = settings.AUTH_BACKEND_SSO + login(self.request, user, settings.AUTH_BACKEND_SSO) + self.send_auth_signal(success=True, user=user) + self.mark_mfa_ok('otp', user) + + LoginIpBlockUtil(ip).clean_block_if_need() + LoginBlockUtil(username, ip).clean_failed_count() + self.clear_auth_mark() + except (ACLError, LoginConfirmBaseError): # 无需记录日志 + pass + except (AuthFailedError, SSOAuthKeyTTLError) as e: + error_msg = e.msg + except Exception as e: + error_msg = str(e) + finally: token.expired = True token.save() - except (ValueError, SSOToken.DoesNotExist): - self.send_auth_signal(success=False, reason='authkey_invalid') - return HttpResponseRedirect(next_url) - # 判断是否过期 - if (utc_now().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL: - self.send_auth_signal(success=False, reason='authkey_timeout') + if error_msg: + self.send_auth_signal(success=False, username=username, reason=error_msg) + return Response({'error': error_msg}, status=status_code) + else: return HttpResponseRedirect(next_url) - - user = token.user - login(self.request, user, settings.AUTH_BACKEND_SSO) - self.send_auth_signal(success=True, user=user) - return HttpResponseRedirect(next_url) diff --git a/apps/authentication/errors/failed.py b/apps/authentication/errors/failed.py index f6d8004c6..729d93b6d 100644 --- a/apps/authentication/errors/failed.py +++ b/apps/authentication/errors/failed.py @@ -52,6 +52,10 @@ class AuthFailedError(Exception): return str(self.msg) +class SSOAuthKeyTTLError(Exception): + msg = 'sso_authkey_timeout' + + class BlockGlobalIpLoginError(AuthFailedError): error = 'block_global_ip_login' diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index 31cb1dc19..721a189d7 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -363,7 +363,6 @@ class AuthACLMixin: if acl.is_action(acl.ActionChoices.notice): self.request.session['auth_notice_required'] = '1' self.request.session['auth_acl_id'] = str(acl.id) - return def _check_third_party_login_acl(self): request = self.request