diff --git a/apps/authentication/middleware.py b/apps/authentication/middleware.py index ff050f815..46aa46ded 100644 --- a/apps/authentication/middleware.py +++ b/apps/authentication/middleware.py @@ -1,17 +1,36 @@ -from django.shortcuts import redirect +from django.shortcuts import redirect, reverse +from django.http import HttpResponse class MFAMiddleware: + """ + 这个 中间件 是用来全局拦截开启了 MFA 却没有认证的,如 OIDC, CAS,使用第三方库做的登录,直接 login 了, + 所以只能在 Middleware 中控制 + """ def __init__(self, get_response): self.get_response = get_response def __call__(self, request): response = self.get_response(request) + # 没有校验 + if not request.session.get('auth_mfa_required'): + return response + # 没有认证过,证明不是从 第三方 来的 + if request.user.is_anonymous: + return response - white_urls = ['login/mfa', 'mfa/select', 'jsi18n/', '/static/'] + # 这个是 mfa 登录页需要的请求, 也得放出来, 用户其实已经在 CAS/OIDC 中完成登录了 + white_urls = [ + 'login/mfa', 'mfa/select', 'jsi18n/', '/static/', '/profile/otp', + '/logout/', '/login/' + ] for url in white_urls: if request.path.find(url) > -1: return response - if request.session.get('auth_mfa_required'): - return redirect('authentication:login-mfa') - return response + + # 因为使用 CAS/OIDC 登录的,不小心去了别的页面就回不来了 + if request.path.find('users/profile') > -1: + return HttpResponse('', status=401) + + url = reverse('authentication:login-mfa') + '?_=middleware' + return redirect(url) diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index 88437dd16..2167ab198 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -257,7 +257,8 @@ class MFAMixin: def _check_login_page_mfa_if_need(self, user): if not settings.SECURITY_MFA_IN_LOGIN_PAGE: return - self._check_if_no_active_mfa(user) + if not user.active_mfa_backends: + return request = self.request data = request.data if hasattr(request, 'data') else request.POST @@ -274,10 +275,8 @@ class MFAMixin: if not user.mfa_enabled: return - self._check_if_no_active_mfa(user) - - active_mfa_mapper = user.active_mfa_backends_mapper - raise errors.MFARequiredError(mfa_types=tuple(active_mfa_mapper.keys())) + active_mfa_names = user.active_mfa_backends_mapper.keys() + raise errors.MFARequiredError(mfa_types=tuple(active_mfa_names)) def mark_mfa_ok(self, mfa_type): self.request.session['auth_mfa'] = 1 diff --git a/apps/authentication/views/mfa.py b/apps/authentication/views/mfa.py index ec51ed63c..fd8b80e32 100644 --- a/apps/authentication/views/mfa.py +++ b/apps/authentication/views/mfa.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from django.views.generic.edit import FormView +from django.shortcuts import redirect from common.utils import get_logger from .. import forms, errors, mixins @@ -19,9 +20,15 @@ class UserLoginMFAView(mixins.AuthMixin, FormView): def get(self, *args, **kwargs): try: - self.get_user_from_session() + user = self.get_user_from_session() except errors.SessionEmptyError: - return redirect_to_guard_view() + return redirect_to_guard_view('session_empty') + + try: + self._check_if_no_active_mfa(user) + except errors.MFAUnsetError as e: + return redirect(e.url + '?_=login_mfa') + return super().get(*args, **kwargs) def form_valid(self, form): @@ -30,17 +37,17 @@ class UserLoginMFAView(mixins.AuthMixin, FormView): try: self._do_check_user_mfa(code, mfa_type) - return redirect_to_guard_view() + return redirect_to_guard_view('mfa_ok') except (errors.MFAFailedError, errors.BlockMFAError) as e: form.add_error('code', e.msg) return super().form_invalid(form) except errors.SessionEmptyError: - return redirect_to_guard_view() + return redirect_to_guard_view('session_empty') except Exception as e: logger.error(e) import traceback traceback.print_exc() - return redirect_to_guard_view() + return redirect_to_guard_view('unexpect') def get_context_data(self, **kwargs): user = self.get_user_from_session() diff --git a/apps/authentication/views/utils.py b/apps/authentication/views/utils.py index 182d7390b..63a1d76c6 100644 --- a/apps/authentication/views/utils.py +++ b/apps/authentication/views/utils.py @@ -3,6 +3,6 @@ from django.shortcuts import reverse, redirect -def redirect_to_guard_view(): - continue_url = reverse('authentication:login-guard') +def redirect_to_guard_view(comment=''): + continue_url = reverse('authentication:login-guard') + '?_=' + comment return redirect(continue_url) diff --git a/apps/static/css/otp.css b/apps/static/css/otp.css index 4c9ed2606..f78916ac1 100644 --- a/apps/static/css/otp.css +++ b/apps/static/css/otp.css @@ -136,6 +136,10 @@ article ul li:last-child{ border-radius: 6px; color: white; } + +.next:hover { + color: white; +} /*绑定TOTP*/ /*版权信息*/ diff --git a/apps/users/templates/users/user_otp_check_password.html b/apps/users/templates/users/user_otp_check_password.html index edb9b2519..501d6eb77 100644 --- a/apps/users/templates/users/user_otp_check_password.html +++ b/apps/users/templates/users/user_otp_check_password.html @@ -7,36 +7,8 @@ {% endblock %} {% block content %} -