feat: MFA 登录次数限制

pull/5922/head
xinwen 2021-04-08 12:47:49 +08:00 committed by 老广
parent 8895763ab4
commit 607b7fd29f
7 changed files with 143 additions and 70 deletions

View File

@ -29,7 +29,7 @@ class MFAChallengeApi(AuthMixin, CreateAPIView):
if not valid: if not valid:
self.request.session['auth_mfa'] = '' self.request.session['auth_mfa'] = ''
raise errors.MFAFailedError( raise errors.MFAFailedError(
username=user.username, request=self.request username=user.username, request=self.request, ip=self.get_request_ip()
) )
else: else:
self.request.session['auth_mfa'] = '1' self.request.session['auth_mfa'] = '1'

View File

@ -6,9 +6,7 @@ from django.conf import settings
from common.exceptions import JMSException from common.exceptions import JMSException
from .signals import post_auth_failed from .signals import post_auth_failed
from users.utils import ( from users.utils import LoginBlockUtil, MFABlockUtils
increase_login_failed_count, get_login_failed_count
)
reason_password_failed = 'password_failed' reason_password_failed = 'password_failed'
reason_password_decrypt_failed = 'password_decrypt_failed' reason_password_decrypt_failed = 'password_decrypt_failed'
@ -52,7 +50,15 @@ block_login_msg = _(
"The account has been locked " "The account has been locked "
"(please contact admin to unlock it or try again after {} minutes)" "(please contact admin to unlock it or try again after {} minutes)"
) )
mfa_failed_msg = _("MFA code invalid, or ntp sync server time") block_mfa_msg = _(
"The account has been locked "
"(please contact admin to unlock it or try again after {} minutes)"
)
mfa_failed_msg = _(
"MFA code invalid, or ntp sync server time, "
"You can also try {times_try} times "
"(The account will be temporarily locked for {block_time} minutes)"
)
mfa_required_msg = _("MFA required") mfa_required_msg = _("MFA required")
mfa_unset_msg = _("MFA not set, please set it first") mfa_unset_msg = _("MFA not set, please set it first")
@ -80,7 +86,7 @@ class AuthFailedNeedBlockMixin:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
increase_login_failed_count(self.username, self.ip) LoginBlockUtil(self.username, self.ip).incr_failed_count()
class AuthFailedError(Exception): class AuthFailedError(Exception):
@ -107,13 +113,12 @@ class AuthFailedError(Exception):
class CredentialError(AuthFailedNeedLogMixin, AuthFailedNeedBlockMixin, AuthFailedError): class CredentialError(AuthFailedNeedLogMixin, AuthFailedNeedBlockMixin, AuthFailedError):
def __init__(self, error, username, ip, request): def __init__(self, error, username, ip, request):
super().__init__(error=error, username=username, ip=ip, request=request) super().__init__(error=error, username=username, ip=ip, request=request)
times_up = settings.SECURITY_LOGIN_LIMIT_COUNT util = LoginBlockUtil(username, ip)
times_failed = get_login_failed_count(username, ip) times_remainder = util.get_remainder_times()
times_try = int(times_up) - int(times_failed)
block_time = settings.SECURITY_LOGIN_LIMIT_TIME block_time = settings.SECURITY_LOGIN_LIMIT_TIME
default_msg = invalid_login_msg.format( default_msg = invalid_login_msg.format(
times_try=times_try, block_time=block_time times_try=times_remainder, block_time=block_time
) )
if error == reason_password_failed: if error == reason_password_failed:
self.msg = default_msg self.msg = default_msg
@ -123,12 +128,32 @@ class CredentialError(AuthFailedNeedLogMixin, AuthFailedNeedBlockMixin, AuthFail
class MFAFailedError(AuthFailedNeedLogMixin, AuthFailedError): class MFAFailedError(AuthFailedNeedLogMixin, AuthFailedError):
error = reason_mfa_failed error = reason_mfa_failed
msg = mfa_failed_msg msg: str
def __init__(self, username, request): def __init__(self, username, request, ip):
util = MFABlockUtils(username, ip)
util.incr_failed_count()
times_remainder = util.get_remainder_times()
block_time = settings.SECURITY_LOGIN_LIMIT_TIME
if times_remainder:
self.msg = mfa_failed_msg.format(
times_try=times_remainder, block_time=block_time
)
else:
self.msg = block_mfa_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME)
super().__init__(username=username, request=request) super().__init__(username=username, request=request)
class BlockMFAError(AuthFailedNeedLogMixin, AuthFailedError):
error = 'block_mfa'
def __init__(self, username, request, ip):
self.msg = block_mfa_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME)
super().__init__(username=username, request=request, ip=ip)
class MFAUnsetError(AuthFailedNeedLogMixin, AuthFailedError): class MFAUnsetError(AuthFailedNeedLogMixin, AuthFailedError):
error = reason_mfa_unset error = reason_mfa_unset
msg = mfa_unset_msg msg = mfa_unset_msg

View File

@ -15,9 +15,7 @@ from django.shortcuts import reverse
from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get
from users.models import User from users.models import User
from users.utils import ( from users.utils import LoginBlockUtil, MFABlockUtils
is_block_login, clean_failed_count
)
from . import errors from . import errors
from .utils import rsa_decrypt from .utils import rsa_decrypt
from .signals import post_auth_success, post_auth_failed from .signals import post_auth_success, post_auth_failed
@ -117,7 +115,7 @@ class AuthMixin:
else: else:
username = self.request.POST.get("username") username = self.request.POST.get("username")
ip = self.get_request_ip() ip = self.get_request_ip()
if is_block_login(username, ip): if LoginBlockUtil(username, ip).is_block():
logger.warn('Ip was blocked' + ': ' + username + ':' + ip) logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockLoginError(username=username, ip=ip) exception = errors.BlockLoginError(username=username, ip=ip)
if raise_exception: if raise_exception:
@ -197,7 +195,7 @@ class AuthMixin:
self._check_password_require_reset_or_not(user) self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password) self._check_passwd_is_too_simple(user, password)
clean_failed_count(username, ip) LoginBlockUtil(username, ip).clean_failed_count()
request.session['auth_password'] = 1 request.session['auth_password'] = 1
request.session['user_id'] = str(user.id) request.session['user_id'] = str(user.id)
request.session['auto_login'] = auto_login request.session['auto_login'] = auto_login
@ -253,15 +251,34 @@ class AuthMixin:
raise errors.MFAUnsetError(user, self.request, url) raise errors.MFAUnsetError(user, self.request, url)
raise errors.MFARequiredError() raise errors.MFARequiredError()
def check_user_mfa(self, code): def mark_mfa_ok(self):
user = self.get_user_from_session()
ok = user.check_mfa(code)
if ok:
self.request.session['auth_mfa'] = 1 self.request.session['auth_mfa'] = 1
self.request.session['auth_mfa_time'] = time.time() self.request.session['auth_mfa_time'] = time.time()
self.request.session['auth_mfa_type'] = 'otp' self.request.session['auth_mfa_type'] = 'otp'
def check_mfa_is_block(self, username, ip, raise_exception=True):
if MFABlockUtils(username, ip).is_block():
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockMFAError(username=username, request=self.request, ip=ip)
if raise_exception:
raise exception
else:
return exception
def check_user_mfa(self, code):
user = self.get_user_from_session()
ip = self.get_request_ip()
self.check_mfa_is_block(user.username, ip)
ok = user.check_mfa(code)
if ok:
self.mark_mfa_ok()
return return
raise errors.MFAFailedError(username=user.username, request=self.request)
raise errors.MFAFailedError(
username=user.username,
request=self.request,
ip=ip
)
def get_ticket(self): def get_ticket(self):
from tickets.models import Ticket from tickets.models import Ticket

View File

@ -22,10 +22,12 @@ class UserLoginOtpView(mixins.AuthMixin, FormView):
try: try:
self.check_user_mfa(otp_code) self.check_user_mfa(otp_code)
return redirect_to_guard_view() return redirect_to_guard_view()
except errors.MFAFailedError as e: except (errors.MFAFailedError, errors.BlockMFAError) as e:
form.add_error('otp_code', e.msg) form.add_error('otp_code', e.msg)
return super().form_invalid(form) return super().form_invalid(form)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
import traceback
traceback.print_exception()
return redirect_to_guard_view() return redirect_to_guard_view()

View File

@ -16,7 +16,7 @@ from common.mixins import CommonApiMixin
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import current_org from orgs.utils import current_org
from orgs.models import ROLE as ORG_ROLE, OrganizationMember from orgs.models import ROLE as ORG_ROLE, OrganizationMember
from users.utils import send_reset_mfa_mail from users.utils import send_reset_mfa_mail, LoginBlockUtil, MFABlockUtils
from .. import serializers from .. import serializers
from ..serializers import UserSerializer, UserRetrieveSerializer, MiniUserSerializer, InviteSerializer from ..serializers import UserSerializer, UserRetrieveSerializer, MiniUserSerializer, InviteSerializer
from .mixins import UserQuerysetMixin from .mixins import UserQuerysetMixin
@ -190,16 +190,12 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.RetrieveUpdateAPIView):
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView): class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.UserSerializer serializer_class = serializers.UserSerializer
key_prefix_limit = "_LOGIN_LIMIT_{}_{}"
key_prefix_block = "_LOGIN_BLOCK_{}"
def perform_update(self, serializer): def perform_update(self, serializer):
user = self.get_object() user = self.get_object()
username = user.username if user else '' username = user.username if user else ''
key_limit = self.key_prefix_limit.format(username, '*') LoginBlockUtil.unblock_user(username)
key_block = self.key_prefix_block.format(username) MFABlockUtils.unblock_user(username)
cache.delete_pattern(key_limit)
cache.delete(key_block)
class UserResetOTPApi(UserQuerysetMixin, generics.RetrieveAPIView): class UserResetOTPApi(UserQuerysetMixin, generics.RetrieveAPIView):

View File

@ -669,10 +669,13 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
@property @property
def login_blocked(self): def login_blocked(self):
key_prefix_block = "_LOGIN_BLOCK_{}" from users.utils import LoginBlockUtil, MFABlockUtils
key_block = key_prefix_block.format(self.username) if LoginBlockUtil.is_user_block(self.username):
blocked = bool(cache.get(key_block)) return True
return blocked if MFABlockUtils.is_user_block(self.username):
return True
return False
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
if self.pk == 1 or self.username == 'admin': if self.pk == 1 or self.username == 'admin':

View File

@ -322,50 +322,80 @@ def check_password_rules(password):
return bool(match_obj) return bool(match_obj)
key_prefix_limit = "_LOGIN_LIMIT_{}_{}" class BlockUtil:
key_prefix_block = "_LOGIN_BLOCK_{}" BLOCK_KEY_TMPL: str
def __init__(self, username):
self.block_key = self.BLOCK_KEY_TMPL.format(username)
self.key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
def block(self):
cache.set(self.block_key, True, self.key_ttl)
def is_block(self):
return bool(cache.get(self.block_key))
# def increase_login_failed_count(key_limit, key_block): class BlockUtilBase:
def increase_login_failed_count(username, ip): LIMIT_KEY_TMPL: str
key_limit = key_prefix_limit.format(username, ip) BLOCK_KEY_TMPL: str
count = cache.get(key_limit)
count = count + 1 if count else 1
limit_time = settings.SECURITY_LOGIN_LIMIT_TIME def __init__(self, username, ip):
cache.set(key_limit, count, int(limit_time)*60) self.username = username
self.ip = ip
self.limit_key = self.LIMIT_KEY_TMPL.format(username, ip)
self.block_key = self.BLOCK_KEY_TMPL.format(username)
self.key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
def get_remainder_times(self):
times_up = settings.SECURITY_LOGIN_LIMIT_COUNT
times_failed = self.get_failed_count()
times_remainder = int(times_up) - int(times_failed)
return times_remainder
def get_login_failed_count(username, ip): def incr_failed_count(self):
key_limit = key_prefix_limit.format(username, ip) limit_key = self.limit_key
count = cache.get(key_limit, 0) count = cache.get(limit_key, 0)
return count count += 1
cache.set(limit_key, count, self.key_ttl)
def clean_failed_count(username, ip):
key_limit = key_prefix_limit.format(username, ip)
key_block = key_prefix_block.format(username)
cache.delete(key_limit)
cache.delete(key_block)
def is_block_login(username, ip):
count = get_login_failed_count(username, ip)
key_block = key_prefix_block.format(username)
limit_count = settings.SECURITY_LOGIN_LIMIT_COUNT limit_count = settings.SECURITY_LOGIN_LIMIT_COUNT
limit_time = settings.SECURITY_LOGIN_LIMIT_TIME
if count >= limit_count: if count >= limit_count:
cache.set(key_block, 1, int(limit_time)*60) cache.set(self.block_key, True, self.key_ttl)
if count and count >= limit_count:
return True def get_failed_count(self):
count = cache.get(self.limit_key, 0)
return count
def clean_failed_count(self):
cache.delete(self.limit_key)
cache.delete(self.block_key)
@classmethod
def unblock_user(cls, username):
key_limit = cls.LIMIT_KEY_TMPL.format(username, '*')
key_block = cls.BLOCK_KEY_TMPL.format(username)
# Redis 尽量不要用通配
cache.delete_pattern(key_limit)
cache.delete(key_block)
@classmethod
def is_user_block(cls, username):
block_key = cls.BLOCK_KEY_TMPL.format(username)
return bool(cache.get(block_key))
def is_block(self):
return bool(cache.get(self.block_key))
def is_need_unblock(key_block): class LoginBlockUtil(BlockUtilBase):
if not cache.get(key_block): LIMIT_KEY_TMPL = "_LOGIN_LIMIT_{}_{}"
return False BLOCK_KEY_TMPL = "_LOGIN_BLOCK_{}"
return True
class MFABlockUtils(BlockUtilBase):
LIMIT_KEY_TMPL = "_MFA_LIMIT_{}_{}"
BLOCK_KEY_TMPL = "_MFA_BLOCK_{}"
def construct_user_email(username, email): def construct_user_email(username, email):