diff --git a/apps/acls/serializers/login_acl.py b/apps/acls/serializers/login_acl.py index cf40e078b..a699ae1ea 100644 --- a/apps/acls/serializers/login_acl.py +++ b/apps/acls/serializers/login_acl.py @@ -2,6 +2,7 @@ from django.utils.translation import ugettext as _ from rest_framework import serializers from common.drf.serializers import BulkModelSerializer from common.drf.serializers import MethodSerializer +from jumpserver.utils import has_valid_xpack_license from ..models import LoginACL from .rules import RuleSerializer @@ -40,12 +41,11 @@ class LoginACLSerializer(BulkModelSerializer): self.set_action_choices() def set_action_choices(self): - from xpack.plugins.license.models import License action = self.fields.get('action') if not action: return choices = action._choices - if not License.has_valid_license(): + if not has_valid_xpack_license(): choices.pop(LoginACL.ActionChoices.confirm, None) action._choices = choices diff --git a/apps/assets/api/system_user.py b/apps/assets/api/system_user.py index 213858e13..b9f1007d9 100644 --- a/apps/assets/api/system_user.py +++ b/apps/assets/api/system_user.py @@ -8,6 +8,7 @@ from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins import generics from common.mixins.api import SuggestionMixin from orgs.utils import tmp_to_root_org +from rest_framework.decorators import action from ..models import SystemUser, Asset from .. import serializers from ..serializers import SystemUserWithAuthInfoSerializer, SystemUserTempAuthSerializer @@ -45,6 +46,32 @@ class SystemUserViewSet(SuggestionMixin, OrgBulkModelViewSet): ordering = ('name', ) permission_classes = (IsOrgAdminOrAppUser,) + @action(methods=['get'], detail=False, url_path='su-from') + def su_from(self, request, *args, **kwargs): + """ API 获取可选的 su_from 系统用户""" + queryset = self.filter_queryset(self.get_queryset()) + queryset = queryset.filter( + protocol=SystemUser.Protocol.ssh, login_mode=SystemUser.LOGIN_AUTO + ) + return self.get_paginate_response_if_need(queryset) + + @action(methods=['get'], detail=True, url_path='su-to') + def su_to(self, request, *args, **kwargs): + """ 获取系统用户的所有 su_to 系统用户 """ + pk = kwargs.get('pk') + system_user = get_object_or_404(SystemUser, pk=pk) + queryset = system_user.su_to.all() + queryset = self.filter_queryset(queryset) + return self.get_paginate_response_if_need(queryset) + + def get_paginate_response_if_need(self, queryset): + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + class SystemUserAuthInfoApi(generics.RetrieveUpdateDestroyAPIView): """ diff --git a/apps/assets/migrations/0079_auto_20211102_1922.py b/apps/assets/migrations/0079_auto_20211102_1922.py new file mode 100644 index 000000000..f0a05dc06 --- /dev/null +++ b/apps/assets/migrations/0079_auto_20211102_1922.py @@ -0,0 +1,28 @@ +# Generated by Django 3.1.12 on 2021-11-02 11:22 + +from django.db import migrations + + +def create_internal_platform(apps, schema_editor): + model = apps.get_model("assets", "Platform") + db_alias = schema_editor.connection.alias + type_platforms = ( + ('Windows-RDP', 'Windows', {'security': 'rdp'}), + ('Windows-TLS', 'Windows', {'security': 'tls'}), + ) + for name, base, meta in type_platforms: + defaults = {'name': name, 'base': base, 'meta': meta, 'internal': True} + model.objects.using(db_alias).update_or_create( + name=name, defaults=defaults + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ('assets', '0078_auto_20211014_2209'), + ] + + operations = [ + migrations.RunPython(create_internal_platform) + ] diff --git a/apps/assets/migrations/0080_auto_20211104_1347.py b/apps/assets/migrations/0080_auto_20211104_1347.py new file mode 100644 index 000000000..75210149e --- /dev/null +++ b/apps/assets/migrations/0080_auto_20211104_1347.py @@ -0,0 +1,24 @@ +# Generated by Django 3.1.13 on 2021-11-04 05:47 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('assets', '0079_auto_20211102_1922'), + ] + + operations = [ + migrations.AddField( + model_name='systemuser', + name='su_enabled', + field=models.BooleanField(default=False, verbose_name='User switch'), + ), + migrations.AddField( + model_name='systemuser', + name='su_from', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='su_to', to='assets.systemuser', verbose_name='Switch from'), + ), + ] diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index 91acd3d34..7d9d6d0d0 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -164,38 +164,7 @@ class Platform(models.Model): # ordering = ('name',) -class Asset(AbsConnectivity, ProtocolsMixin, NodesRelationMixin, OrgModelMixin): - # Important - PLATFORM_CHOICES = ( - ('Linux', 'Linux'), - ('Unix', 'Unix'), - ('MacOS', 'MacOS'), - ('BSD', 'BSD'), - ('Windows', 'Windows'), - ('Windows2016', 'Windows(2016)'), - ('Other', 'Other'), - ) - - id = models.UUIDField(default=uuid.uuid4, primary_key=True) - ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True) - hostname = models.CharField(max_length=128, verbose_name=_('Hostname')) - protocol = models.CharField(max_length=128, default=ProtocolsMixin.Protocol.ssh, - choices=ProtocolsMixin.Protocol.choices, - verbose_name=_('Protocol')) - port = models.IntegerField(default=22, verbose_name=_('Port')) - protocols = models.CharField(max_length=128, default='ssh/22', blank=True, verbose_name=_("Protocols")) - platform = models.ForeignKey(Platform, default=Platform.default, on_delete=models.PROTECT, verbose_name=_("Platform"), related_name='assets') - domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL) - nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")) - is_active = models.BooleanField(default=True, verbose_name=_('Is active')) - - # Auth - admin_user = models.ForeignKey('assets.SystemUser', on_delete=models.SET_NULL, null=True, verbose_name=_("Admin user"), related_name='admin_assets') - - # Some information - public_ip = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Public IP')) - number = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Asset number')) - +class AbsHardwareInfo(models.Model): # Collect vendor = models.CharField(max_length=64, null=True, blank=True, verbose_name=_('Vendor')) model = models.CharField(max_length=54, null=True, blank=True, verbose_name=_('Model')) @@ -214,6 +183,49 @@ class Asset(AbsConnectivity, ProtocolsMixin, NodesRelationMixin, OrgModelMixin): os_arch = models.CharField(max_length=16, blank=True, null=True, verbose_name=_('OS arch')) hostname_raw = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Hostname raw')) + class Meta: + abstract = True + + @property + def cpu_info(self): + info = "" + if self.cpu_model: + info += self.cpu_model + if self.cpu_count and self.cpu_cores: + info += "{}*{}".format(self.cpu_count, self.cpu_cores) + return info + + @property + def hardware_info(self): + if self.cpu_count: + return '{} Core {} {}'.format( + self.cpu_vcpus or self.cpu_count * self.cpu_cores, + self.memory, self.disk_total + ) + else: + return '' + + +class Asset(AbsConnectivity, AbsHardwareInfo, ProtocolsMixin, NodesRelationMixin, OrgModelMixin): + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True) + hostname = models.CharField(max_length=128, verbose_name=_('Hostname')) + protocol = models.CharField(max_length=128, default=ProtocolsMixin.Protocol.ssh, + choices=ProtocolsMixin.Protocol.choices, verbose_name=_('Protocol')) + port = models.IntegerField(default=22, verbose_name=_('Port')) + protocols = models.CharField(max_length=128, default='ssh/22', blank=True, verbose_name=_("Protocols")) + platform = models.ForeignKey(Platform, default=Platform.default, on_delete=models.PROTECT, verbose_name=_("Platform"), related_name='assets') + domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL) + nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")) + is_active = models.BooleanField(default=True, verbose_name=_('Is active')) + + # Auth + admin_user = models.ForeignKey('assets.SystemUser', on_delete=models.SET_NULL, null=True, verbose_name=_("Admin user"), related_name='admin_assets') + + # Some information + public_ip = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Public IP')) + number = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Asset number')) + labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels")) created_by = models.CharField(max_length=128, null=True, blank=True, verbose_name=_('Created by')) date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created')) @@ -269,25 +281,6 @@ class Asset(AbsConnectivity, ProtocolsMixin, NodesRelationMixin, OrgModelMixin): def is_support_ansible(self): return self.has_protocol('ssh') and self.platform_base not in ("Other",) - @property - def cpu_info(self): - info = "" - if self.cpu_model: - info += self.cpu_model - if self.cpu_count and self.cpu_cores: - info += "{}*{}".format(self.cpu_count, self.cpu_cores) - return info - - @property - def hardware_info(self): - if self.cpu_count: - return '{} Core {} {}'.format( - self.cpu_vcpus or self.cpu_count * self.cpu_cores, - self.memory, self.disk_total - ) - else: - return '' - def get_auth_info(self): if not self.admin_user: return {} diff --git a/apps/assets/models/user.py b/apps/assets/models/user.py index 52e3c2af8..5f1a8df76 100644 --- a/apps/assets/models/user.py +++ b/apps/assets/models/user.py @@ -208,6 +208,9 @@ class SystemUser(ProtocolMixin, AuthMixin, BaseUser): home = models.CharField(max_length=4096, default='', verbose_name=_('Home'), blank=True) system_groups = models.CharField(default='', max_length=4096, verbose_name=_('System groups'), blank=True) ad_domain = models.CharField(default='', max_length=256) + # linux su 命令 (switch user) + su_enabled = models.BooleanField(default=False, verbose_name=_('User switch')) + su_from = models.ForeignKey('self', on_delete=models.SET_NULL, related_name='su_to', null=True, verbose_name=_("Switch from")) def __str__(self): username = self.username @@ -267,6 +270,21 @@ class SystemUser(ProtocolMixin, AuthMixin, BaseUser): assets = Asset.objects.filter(id__in=asset_ids) return assets + def add_related_assets(self, assets_or_ids): + self.assets.add(*tuple(assets_or_ids)) + self.add_related_assets_to_su_from_if_need(assets_or_ids) + + def add_related_assets_to_su_from_if_need(self, assets_or_ids): + if self.protocol not in [self.Protocol.ssh.value]: + return + if not self.su_enabled: + return + if not self.su_from: + return + if self.su_from.protocol != self.protocol: + return + self.su_from.assets.add(*tuple(assets_or_ids)) + class Meta: ordering = ['name'] unique_together = [('name', 'org_id')] diff --git a/apps/assets/serializers/asset.py b/apps/assets/serializers/asset.py index fccbcaa9e..b13eda715 100644 --- a/apps/assets/serializers/asset.py +++ b/apps/assets/serializers/asset.py @@ -66,7 +66,9 @@ class AssetSerializer(BulkOrgResourceModelSerializer): ) protocols = ProtocolsField(label=_('Protocols'), required=False, default=['ssh/22']) domain_display = serializers.ReadOnlyField(source='domain.name', label=_('Domain name')) - nodes_display = serializers.ListField(child=serializers.CharField(), label=_('Nodes name'), required=False) + nodes_display = serializers.ListField( + child=serializers.CharField(), label=_('Nodes name'), required=False + ) """ 资产的数据结构 @@ -79,11 +81,11 @@ class AssetSerializer(BulkOrgResourceModelSerializer): 'protocol', 'port', 'protocols', 'is_active', 'public_ip', 'number', 'comment', ] - hardware_fields = [ + fields_hardware = [ 'vendor', 'model', 'sn', 'cpu_model', 'cpu_count', 'cpu_cores', 'cpu_vcpus', 'memory', 'disk_total', 'disk_info', - 'os', 'os_version', 'os_arch', 'hostname_raw', 'hardware_info', - 'connectivity', 'date_verified' + 'os', 'os_version', 'os_arch', 'hostname_raw', + 'cpu_info', 'hardware_info', ] fields_fk = [ 'domain', 'domain_display', 'platform', 'admin_user', 'admin_user_display' @@ -92,18 +94,16 @@ class AssetSerializer(BulkOrgResourceModelSerializer): 'nodes', 'nodes_display', 'labels', ] read_only_fields = [ + 'connectivity', 'date_verified', 'cpu_info', 'hardware_info', 'created_by', 'date_created', ] - fields = fields_small + hardware_fields + fields_fk + fields_m2m + read_only_fields - - extra_kwargs = {k: {'read_only': True} for k in hardware_fields} - extra_kwargs.update({ + fields = fields_small + fields_hardware + fields_fk + fields_m2m + read_only_fields + extra_kwargs = { 'protocol': {'write_only': True}, 'port': {'write_only': True}, 'hardware_info': {'label': _('Hardware info'), 'read_only': True}, - 'org_name': {'label': _('Org name'), 'read_only': True}, 'admin_user_display': {'label': _('Admin user display'), 'read_only': True}, - }) + } def get_fields(self): fields = super().get_fields() diff --git a/apps/assets/serializers/system_user.py b/apps/assets/serializers/system_user.py index b662d062c..b86740f8c 100644 --- a/apps/assets/serializers/system_user.py +++ b/apps/assets/serializers/system_user.py @@ -40,6 +40,7 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer): 'login_mode', 'login_mode_display', 'priority', 'sudo', 'shell', 'sftp_root', 'home', 'system_groups', 'ad_domain', 'username_same_with_user', 'auto_push', 'auto_generate_key', + 'su_enabled', 'su_from', 'date_created', 'date_updated', 'comment', 'created_by', ] fields_m2m = ['cmd_filters', 'assets_amount', 'applications_amount', 'nodes'] @@ -57,7 +58,8 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer): 'login_mode_display': {'label': _('Login mode display')}, 'created_by': {'read_only': True}, 'ad_domain': {'required': False, 'allow_blank': True, 'label': _('Ad domain')}, - 'is_asset_protocol': {'label': _('Is asset protocol')} + 'is_asset_protocol': {'label': _('Is asset protocol')}, + 'su_from': {'help_text': _('Only ssh and automatic login system users are supported')} } def validate_auto_push(self, value): @@ -146,6 +148,29 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer): raise serializers.ValidationError(_("Password or private key required")) return password + def validate_su_from(self, su_from: SystemUser): + # self: su enabled + su_enabled = self.get_initial_value('su_enabled', default=False) + if not su_enabled: + return + if not su_from: + error = _('This field is required.') + raise serializers.ValidationError(error) + # self: protocol ssh + protocol = self.get_initial_value('protocol', default=SystemUser.Protocol.ssh.value) + if protocol not in [SystemUser.Protocol.ssh.value]: + error = _('Only ssh protocol system users are allowed') + raise serializers.ValidationError(error) + # su_from: protocol same + if su_from.protocol != protocol: + error = _('The protocol must be consistent with the current user: {}').format(protocol) + raise serializers.ValidationError(error) + # su_from: login model auto + if su_from.login_mode != su_from.LOGIN_AUTO: + error = _('Only system users with automatic login are allowed') + raise serializers.ValidationError(error) + return su_from + def _validate_admin_user(self, attrs): if self.instance: tp = self.instance.type diff --git a/apps/assets/signals_handler/system_user.py b/apps/assets/signals_handler/system_user.py index 00111030c..00b19e110 100644 --- a/apps/assets/signals_handler/system_user.py +++ b/apps/assets/signals_handler/system_user.py @@ -140,3 +140,5 @@ def on_system_user_update(instance: SystemUser, created, **kwargs): logger.info("System user update signal recv: {}".format(instance)) assets = instance.assets.all().valid() push_system_user_to_assets.delay(instance.id, [_asset.id for _asset in assets]) + # add assets to su_from + instance.add_related_assets_to_su_from_if_need(assets) diff --git a/apps/audits/signals_handler.py b/apps/audits/signals_handler.py index 4ba7e8408..362bd4c11 100644 --- a/apps/audits/signals_handler.py +++ b/apps/audits/signals_handler.py @@ -15,7 +15,7 @@ from rest_framework.request import Request from assets.models import Asset, SystemUser from authentication.signals import post_auth_failed, post_auth_success -from authentication.utils import check_different_city_login +from authentication.utils import check_different_city_login_if_need from jumpserver.utils import current_request from users.models import User from users.signals import post_user_change_password @@ -304,7 +304,7 @@ def generate_data(username, request, login_type=None): @receiver(post_auth_success) def on_user_auth_success(sender, user, request, login_type=None, **kwargs): logger.debug('User login success: {}'.format(user.username)) - check_different_city_login(user, request) + check_different_city_login_if_need(user, request) data = generate_data(user.username, request, login_type=login_type) data.update({'mfa': int(user.mfa_enabled), 'status': True}) write_login_log(**data) diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 842011a94..2ff4fa71c 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -4,6 +4,7 @@ import urllib.parse import json import base64 from typing import Callable +import os from django.conf import settings from django.core.cache import cache @@ -50,6 +51,10 @@ class ClientProtocolMixin: user = self.request.user return asset, application, system_user, user + @staticmethod + def parse_env_bool(env_key, env_default, true_value, false_value): + return true_value if is_true(os.getenv(env_key, env_default)) else false_value + def get_rdp_file_content(self, serializer): options = { 'full address:s': '', @@ -112,6 +117,10 @@ class ClientProtocolMixin: options['desktopheight:i'] = height else: options['smart sizing:i'] = '1' + + options['session bpp:i'] = os.getenv('JUMPSERVER_COLOR_DEPTH', '32') + options['audiomode:i'] = self.parse_env_bool('JUMPSERVER_DISABLE_AUDIO', 'false', '2', '0') + content = '' for k, v in options.items(): content += f'{k}:{v}\n' diff --git a/apps/authentication/api/mfa.py b/apps/authentication/api/mfa.py index 978c52072..751231819 100644 --- a/apps/authentication/api/mfa.py +++ b/apps/authentication/api/mfa.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- # -import builtins import time from django.utils.translation import ugettext as _ @@ -12,55 +11,76 @@ from rest_framework.serializers import ValidationError from rest_framework.response import Response from common.permissions import IsValidUser, NeedMFAVerify -from users.models.user import MFAType, User +from common.utils import get_logger +from users.models.user import User from ..serializers import OtpVerifySerializer from .. import serializers from .. import errors +from ..mfa.otp import MFAOtp from ..mixins import AuthMixin -__all__ = ['MFAChallengeApi', 'UserOtpVerifyApi', 'SendSMSVerifyCodeApi', 'MFASelectTypeApi'] +logger = get_logger(__name__) + +__all__ = [ + 'MFAChallengeVerifyApi', 'UserOtpVerifyApi', + 'MFASendCodeApi' +] -class MFASelectTypeApi(AuthMixin, CreateAPIView): +# MFASelectAPi 原来的名字 +class MFASendCodeApi(AuthMixin, CreateAPIView): + """ + 选择 MFA 后对应操作 api,koko 目前在用 + """ permission_classes = (AllowAny,) serializer_class = serializers.MFASelectTypeSerializer def perform_create(self, serializer): + username = serializer.validated_data.get('username', '') mfa_type = serializer.validated_data['type'] - if mfa_type == MFAType.SMS_CODE: + if not username: user = self.get_user_from_session() - user.send_sms_code() + else: + user = get_object_or_404(User, username=username) + + mfa_backend = user.get_mfa_backend_by_type(mfa_type) + if not mfa_backend or not mfa_backend.challenge_required: + raise ValidationError('MFA type not support: {} {}'.format(mfa_type, mfa_backend)) + mfa_backend.send_challenge() + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + try: + self.perform_create(serializer) + return Response(serializer.data, status=201) + except Exception as e: + logger.exception(e) + return Response({'error': str(e)}, status=400) -class MFAChallengeApi(AuthMixin, CreateAPIView): +class MFAChallengeVerifyApi(AuthMixin, CreateAPIView): permission_classes = (AllowAny,) serializer_class = serializers.MFAChallengeSerializer def perform_create(self, serializer): - try: - user = self.get_user_from_session() - code = serializer.validated_data.get('code') - mfa_type = serializer.validated_data.get('type', MFAType.OTP) + user = self.get_user_from_session() + code = serializer.validated_data.get('code') + mfa_type = serializer.validated_data.get('type', '') + self._do_check_user_mfa(code, mfa_type, user) - valid = user.check_mfa(code, mfa_type=mfa_type) - if not valid: - self.request.session['auth_mfa'] = '' - raise errors.MFAFailedError( - username=user.username, request=self.request, ip=self.get_request_ip() - ) - else: - self.request.session['auth_mfa'] = '1' + def create(self, request, *args, **kwargs): + try: + super().create(request, *args, **kwargs) + return Response({'msg': 'ok'}) except errors.AuthFailedError as e: data = {"error": e.error, "msg": e.msg} raise ValidationError(data) except errors.NeedMoreInfoError as e: return Response(e.as_data(), status=200) - def create(self, request, *args, **kwargs): - super().create(request, *args, **kwargs) - return Response({'msg': 'ok'}) - class UserOtpVerifyApi(CreateAPIView): permission_classes = (IsValidUser,) @@ -73,30 +93,17 @@ class UserOtpVerifyApi(CreateAPIView): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) code = serializer.validated_data["code"] + otp = MFAOtp(request.user) - if request.user.check_mfa(code): + ok, error = otp.check_code(code) + if ok: request.session["MFA_VERIFY_TIME"] = int(time.time()) return Response({"ok": "1"}) else: - return Response({"error": _("Code is invalid")}, status=400) + return Response({"error": _("Code is invalid") + ", " + error}, status=400) def get_permissions(self): - if self.request.method.lower() == 'get' and settings.SECURITY_VIEW_AUTH_NEED_MFA: + if self.request.method.lower() == 'get' \ + and settings.SECURITY_VIEW_AUTH_NEED_MFA: self.permission_classes = [NeedMFAVerify] return super().get_permissions() - - -class SendSMSVerifyCodeApi(AuthMixin, CreateAPIView): - permission_classes = (AllowAny,) - - def create(self, request, *args, **kwargs): - username = request.data.get('username', '') - username = username.strip() - if username: - user = get_object_or_404(User, username=username) - else: - user = self.get_user_from_session() - if not user.mfa_enabled: - raise errors.NotEnableMFAError - timeout = user.send_sms_code() - return Response({'code': 'ok', 'timeout': timeout}) diff --git a/apps/authentication/api/password.py b/apps/authentication/api/password.py index af8b41358..95ebe6edc 100644 --- a/apps/authentication/api/password.py +++ b/apps/authentication/api/password.py @@ -4,7 +4,7 @@ from rest_framework.response import Response from authentication.serializers import PasswordVerifySerializer from common.permissions import IsValidUser from authentication.mixins import authenticate -from authentication.errors import PasswdInvalid +from authentication.errors import PasswordInvalid from authentication.mixins import AuthMixin @@ -20,7 +20,7 @@ class UserPasswordVerifyApi(AuthMixin, CreateAPIView): user = authenticate(request=request, username=user.username, password=password) if not user: - raise PasswdInvalid + raise PasswordInvalid - self.set_passwd_verify_on_session(user) + self.mark_password_ok(user) return Response() diff --git a/apps/authentication/api/token.py b/apps/authentication/api/token.py index f7516496c..d8e8eb6fc 100644 --- a/apps/authentication/api/token.py +++ b/apps/authentication/api/token.py @@ -40,5 +40,5 @@ class TokenCreateApi(AuthMixin, CreateAPIView): return Response(e.as_data(), status=400) except errors.NeedMoreInfoError as e: return Response(e.as_data(), status=200) - except errors.PasswdTooSimple as e: + except errors.PasswordTooSimple as e: return redirect(e.url) diff --git a/apps/authentication/errors.py b/apps/authentication/errors.py index 8a6f219bd..19b13ab8e 100644 --- a/apps/authentication/errors.py +++ b/apps/authentication/errors.py @@ -8,7 +8,6 @@ from rest_framework import status from common.exceptions import JMSException from .signals import post_auth_failed from users.utils import LoginBlockUtil, MFABlockUtils -from users.models import MFAType reason_password_failed = 'password_failed' reason_password_decrypt_failed = 'password_decrypt_failed' @@ -60,22 +59,11 @@ block_mfa_msg = _( "The account has been locked " "(please contact admin to unlock it or try again after {} minutes)" ) -otp_failed_msg = _( - "One-time password invalid, or ntp sync server time, " +mfa_error_msg = _( + "{error}," "You can also try {times_try} times " "(The account will be temporarily locked for {block_time} minutes)" ) -sms_failed_msg = _( - "SMS verify code invalid," - "You can also try {times_try} times " - "(The account will be temporarily locked for {block_time} minutes)" -) -mfa_type_failed_msg = _( - "The MFA type({mfa_type}) is not supported, " - "You can also try {times_try} times " - "(The account will be temporarily locked for {block_time} minutes)" -) - mfa_required_msg = _("MFA required") mfa_unset_msg = _("MFA not set, please set it first") otp_unset_msg = _("OTP not set, please set it first") @@ -151,29 +139,19 @@ class MFAFailedError(AuthFailedNeedLogMixin, AuthFailedError): error = reason_mfa_failed msg: str - def __init__(self, username, request, ip, mfa_type=MFAType.OTP): - util = MFABlockUtils(username, ip) - util.incr_failed_count() + def __init__(self, username, request, ip, mfa_type, error): + super().__init__(username=username, request=request) - times_remainder = util.get_remainder_times() + util = MFABlockUtils(username, ip) + times_remainder = util.incr_failed_count() block_time = settings.SECURITY_LOGIN_LIMIT_TIME if times_remainder: - if mfa_type == MFAType.OTP: - self.msg = otp_failed_msg.format( - times_try=times_remainder, block_time=block_time - ) - elif mfa_type == MFAType.SMS_CODE: - self.msg = sms_failed_msg.format( - times_try=times_remainder, block_time=block_time - ) - else: - self.msg = mfa_type_failed_msg.format( - mfa_type=mfa_type, times_try=times_remainder, block_time=block_time - ) + self.msg = mfa_error_msg.format( + error=error, times_try=times_remainder, block_time=block_time + ) else: self.msg = block_mfa_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME) - super().__init__(username=username, request=request) class BlockMFAError(AuthFailedNeedLogMixin, AuthFailedError): @@ -228,7 +206,7 @@ class MFARequiredError(NeedMoreInfoError): msg = mfa_required_msg error = 'mfa_required' - def __init__(self, error='', msg='', mfa_types=tuple(MFAType)): + def __init__(self, error='', msg='', mfa_types=()): super().__init__(error=error, msg=msg) self.choices = mfa_types @@ -305,7 +283,7 @@ class SSOAuthClosed(JMSException): default_detail = _('SSO auth closed') -class PasswdTooSimple(JMSException): +class PasswordTooSimple(JMSException): default_code = 'passwd_too_simple' default_detail = _('Your password is too simple, please change it for security') @@ -314,7 +292,7 @@ class PasswdTooSimple(JMSException): self.url = url -class PasswdNeedUpdate(JMSException): +class PasswordNeedUpdate(JMSException): default_code = 'passwd_need_update' default_detail = _('You should to change your password before login') @@ -357,7 +335,7 @@ class FeiShuNotBound(JMSException): default_detail = 'FeiShu is not bound' -class PasswdInvalid(JMSException): +class PasswordInvalid(JMSException): default_code = 'passwd_invalid' default_detail = _('Your password is invalid') @@ -368,10 +346,6 @@ class NotHaveUpDownLoadPerm(JMSException): default_detail = _('No upload or download permission') -class NotEnableMFAError(JMSException): - default_detail = mfa_unset_msg - - class OTPBindRequiredError(JMSException): default_detail = otp_unset_msg @@ -380,11 +354,13 @@ class OTPBindRequiredError(JMSException): self.url = url -class OTPCodeRequiredError(AuthFailedError): +class MFACodeRequiredError(AuthFailedError): msg = _("Please enter MFA code") + class SMSCodeRequiredError(AuthFailedError): msg = _("Please enter SMS code") + class UserPhoneNotSet(AuthFailedError): msg = _('Phone not set') diff --git a/apps/authentication/mfa/__init__.py b/apps/authentication/mfa/__init__.py new file mode 100644 index 000000000..16279eb0d --- /dev/null +++ b/apps/authentication/mfa/__init__.py @@ -0,0 +1,5 @@ +from .otp import MFAOtp, otp_failed_msg +from .sms import MFASms +from .radius import MFARadius + +MFA_BACKENDS = [MFAOtp, MFASms, MFARadius] diff --git a/apps/authentication/mfa/base.py b/apps/authentication/mfa/base.py new file mode 100644 index 000000000..2158b8fe1 --- /dev/null +++ b/apps/authentication/mfa/base.py @@ -0,0 +1,72 @@ +import abc + +from django.utils.translation import ugettext_lazy as _ + + +class BaseMFA(abc.ABC): + placeholder = _('Please input security code') + + def __init__(self, user): + """ + :param user: Authenticated user, Anonymous or None + 因为首页登录时,可能没法获取到一些状态 + """ + self.user = user + + def is_authenticated(self): + return self.user and self.user.is_authenticated + + @property + @abc.abstractmethod + def name(self): + return '' + + @property + @abc.abstractmethod + def display_name(self): + return '' + + @staticmethod + def challenge_required(): + return False + + def send_challenge(self): + pass + + @abc.abstractmethod + def check_code(self, code) -> tuple: + return False, 'Error msg' + + @abc.abstractmethod + def is_active(self): + return False + + @staticmethod + @abc.abstractmethod + def global_enabled(): + return False + + @abc.abstractmethod + def get_enable_url(self) -> str: + return '' + + @abc.abstractmethod + def get_disable_url(self) -> str: + return '' + + @abc.abstractmethod + def disable(self): + pass + + @abc.abstractmethod + def can_disable(self) -> bool: + return True + + @staticmethod + def help_text_of_enable(): + return '' + + @staticmethod + def help_text_of_disable(): + return '' + diff --git a/apps/authentication/mfa/otp.py b/apps/authentication/mfa/otp.py new file mode 100644 index 000000000..9d67c4ae2 --- /dev/null +++ b/apps/authentication/mfa/otp.py @@ -0,0 +1,51 @@ +from django.utils.translation import gettext_lazy as _ +from django.shortcuts import reverse + +from .base import BaseMFA + + +otp_failed_msg = _("OTP code invalid, or server time error") + + +class MFAOtp(BaseMFA): + name = 'otp' + display_name = _('OTP') + + def check_code(self, code): + from users.utils import check_otp_code + assert self.is_authenticated() + + ok = check_otp_code(self.user.otp_secret_key, code) + msg = '' if ok else otp_failed_msg + return ok, msg + + def is_active(self): + if not self.is_authenticated(): + return True + return self.user.otp_secret_key + + @staticmethod + def global_enabled(): + return True + + def get_enable_url(self) -> str: + return reverse('authentication:user-otp-enable-start') + + def disable(self): + assert self.is_authenticated() + self.user.otp_secret_key = '' + self.user.save(update_fields=['otp_secret_key']) + + def can_disable(self) -> bool: + return True + + def get_disable_url(self): + return reverse('authentication:user-otp-disable') + + @staticmethod + def help_text_of_enable(): + return _("Virtual OTP based MFA") + + def help_text_of_disable(self): + return '' + diff --git a/apps/authentication/mfa/radius.py b/apps/authentication/mfa/radius.py new file mode 100644 index 000000000..ad20456c1 --- /dev/null +++ b/apps/authentication/mfa/radius.py @@ -0,0 +1,46 @@ +from django.utils.translation import ugettext_lazy as _ +from django.conf import settings + +from .base import BaseMFA +from ..backends.radius import RadiusBackend + +mfa_failed_msg = _("Radius verify code invalid") + + +class MFARadius(BaseMFA): + name = 'otp_radius' + display_name = _('Radius MFA') + + def check_code(self, code): + assert self.is_authenticated() + backend = RadiusBackend() + username = self.user.username + user = backend.authenticate( + None, username=username, password=code + ) + ok = user is not None + msg = '' if ok else mfa_failed_msg + return ok, msg + + def is_active(self): + return True + + @staticmethod + def global_enabled(): + return settings.OTP_IN_RADIUS + + def get_enable_url(self) -> str: + return '' + + def can_disable(self): + return False + + def disable(self): + return '' + + @staticmethod + def help_text_of_disable(): + return _("Radius global enabled, cannot disable") + + def get_disable_url(self) -> str: + return '' diff --git a/apps/authentication/mfa/sms.py b/apps/authentication/mfa/sms.py new file mode 100644 index 000000000..cc2855cfd --- /dev/null +++ b/apps/authentication/mfa/sms.py @@ -0,0 +1,60 @@ +from django.utils.translation import ugettext_lazy as _ +from django.conf import settings + +from .base import BaseMFA +from common.sdk.sms import SendAndVerifySMSUtil + +sms_failed_msg = _("SMS verify code invalid") + + +class MFASms(BaseMFA): + name = 'sms' + display_name = _("SMS") + placeholder = _("SMS verification code") + + def __init__(self, user): + super().__init__(user) + phone = user.phone if self.is_authenticated() else '' + self.sms = SendAndVerifySMSUtil(phone) + + def check_code(self, code): + assert self.is_authenticated() + ok = self.sms.verify(code) + msg = '' if ok else sms_failed_msg + return ok, msg + + def is_active(self): + if not self.is_authenticated(): + return True + return self.user.phone + + @staticmethod + def challenge_required(): + return True + + def send_challenge(self): + self.sms.gen_and_send() + + @staticmethod + def global_enabled(): + return settings.SMS_ENABLED + + def get_enable_url(self) -> str: + return '/ui/#/users/profile/?activeTab=ProfileUpdate' + + def can_disable(self) -> bool: + return True + + def disable(self): + return '/ui/#/users/profile/?activeTab=ProfileUpdate' + + @staticmethod + def help_text_of_enable(): + return _("Set phone number to enable") + + @staticmethod + def help_text_of_disable(): + return _("Clear phone number to disable") + + def get_disable_url(self) -> str: + return '/ui/#/users/profile/?activeTab=ProfileUpdate' diff --git a/apps/authentication/middleware.py b/apps/authentication/middleware.py index 59eabff75..9a5e4e793 100644 --- a/apps/authentication/middleware.py +++ b/apps/authentication/middleware.py @@ -10,5 +10,5 @@ class MFAMiddleware: if request.path.find('/auth/login/otp/') > -1: return response if request.session.get('auth_mfa_required'): - return redirect('authentication:login-otp') + return redirect('authentication:login-mfa') return response diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index d07cfb0d7..a7d845662 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -1,24 +1,26 @@ # -*- coding: utf-8 -*- # import inspect -from django.utils.http import urlencode from functools import partial import time +from typing import Callable +from django.utils.http import urlencode from django.core.cache import cache from django.conf import settings from django.urls import reverse_lazy from django.contrib import auth from django.utils.translation import ugettext as _ +from rest_framework.request import Request from django.contrib.auth import ( BACKEND_SESSION_KEY, _get_backends, PermissionDenied, user_login_failed, _clean_credentials ) -from django.shortcuts import reverse, redirect +from django.shortcuts import reverse, redirect, get_object_or_404 from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get, FlashMessageUtil from acls.models import LoginACL -from users.models import User, MFAType +from users.models import User from users.utils import LoginBlockUtil, MFABlockUtils from . import errors from .utils import rsa_decrypt, gen_key_pair @@ -32,8 +34,7 @@ def check_backend_can_auth(username, backend_path, allowed_auth_backends): if allowed_auth_backends is not None and backend_path not in allowed_auth_backends: logger.debug('Skip user auth backend: {}, {} not in'.format( username, backend_path, ','.join(allowed_auth_backends) - ) - ) + )) return False return True @@ -109,17 +110,18 @@ class PasswordEncryptionViewMixin: def decrypt_passwd(self, raw_passwd): # 获取解密密钥,对密码进行解密 rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY) - if rsa_private_key is not None: - try: - return rsa_decrypt(raw_passwd, rsa_private_key) - except Exception as e: - logger.error(e, exc_info=True) - logger.error( - f'Decrypt password failed: password[{raw_passwd}] ' - f'rsa_private_key[{rsa_private_key}]' - ) - return None - return raw_passwd + if rsa_private_key is None: + return raw_passwd + + try: + return rsa_decrypt(raw_passwd, rsa_private_key) + except Exception as e: + logger.error(e, exc_info=True) + logger.error( + f'Decrypt password failed: password[{raw_passwd}] ' + f'rsa_private_key[{rsa_private_key}]' + ) + return None def get_request_ip(self): ip = '' @@ -132,7 +134,7 @@ class PasswordEncryptionViewMixin: # 生成加解密密钥对,public_key传递给前端,private_key存入session中供解密使用 rsa_public_key = self.request.session.get(RSA_PUBLIC_KEY) rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY) - if not all((rsa_private_key, rsa_public_key)): + if not all([rsa_private_key, rsa_public_key]): rsa_private_key, rsa_public_key = gen_key_pair() rsa_public_key = rsa_public_key.replace('\n', '\\n') self.request.session[RSA_PRIVATE_KEY] = rsa_private_key @@ -144,49 +146,9 @@ class PasswordEncryptionViewMixin: return super().get_context_data(**kwargs) -class AuthMixin(PasswordEncryptionViewMixin): - request = None - partial_credential_error = None - - key_prefix_captcha = "_LOGIN_INVALID_{}" - - def get_user_from_session(self): - if self.request.session.is_empty(): - raise errors.SessionEmptyError() - - if all((self.request.user, - not self.request.user.is_anonymous, - BACKEND_SESSION_KEY in self.request.session)): - user = self.request.user - user.backend = self.request.session[BACKEND_SESSION_KEY] - return user - - user_id = self.request.session.get('user_id') - if not user_id: - user = None - else: - user = get_object_or_none(User, pk=user_id) - if not user: - raise errors.SessionEmptyError() - user.backend = self.request.session.get("auth_backend") - return user - - def _check_is_block(self, username, raise_exception=True): - ip = self.get_request_ip() - if LoginBlockUtil(username, ip).is_block(): - logger.warn('Ip was blocked' + ': ' + username + ':' + ip) - exception = errors.BlockLoginError(username=username, ip=ip) - if raise_exception: - raise errors.BlockLoginError(username=username, ip=ip) - else: - return exception - - def check_is_block(self, raise_exception=True): - if hasattr(self.request, 'data'): - username = self.request.data.get("username") - else: - username = self.request.POST.get("username") - self._check_is_block(username, raise_exception) +class CommonMixin(PasswordEncryptionViewMixin): + request: Request + get_request_ip: Callable def raise_credential_error(self, error): raise self.partial_credential_error(error=error) @@ -197,6 +159,31 @@ class AuthMixin(PasswordEncryptionViewMixin): ip=ip, request=request ) + def get_user_from_session(self): + if self.request.session.is_empty(): + raise errors.SessionEmptyError() + + if all([ + self.request.user, + not self.request.user.is_anonymous, + BACKEND_SESSION_KEY in self.request.session + ]): + user = self.request.user + user.backend = self.request.session[BACKEND_SESSION_KEY] + return user + + user_id = self.request.session.get('user_id') + auth_password = self.request.session.get('auth_password') + auth_expired_at = self.request.session.get('auth_password_expired_at') + auth_expired = auth_expired_at < time.time() if auth_expired_at else False + + if not user_id or not auth_password or auth_expired: + raise errors.SessionEmptyError() + + user = get_object_or_404(User, pk=user_id) + user.backend = self.request.session.get("auth_backend") + return user + def get_auth_data(self, decrypt_passwd=False): request = self.request if hasattr(request, 'data'): @@ -214,6 +201,31 @@ class AuthMixin(PasswordEncryptionViewMixin): password = password + challenge.strip() return username, password, public_key, ip, auto_login + +class AuthPreCheckMixin: + request: Request + get_request_ip: Callable + raise_credential_error: Callable + + def _check_is_block(self, username, raise_exception=True): + ip = self.get_request_ip() + is_block = LoginBlockUtil(username, ip).is_block() + if not is_block: + return + logger.warn('Ip was blocked' + ': ' + username + ':' + ip) + exception = errors.BlockLoginError(username=username, ip=ip) + if raise_exception: + raise errors.BlockLoginError(username=username, ip=ip) + else: + return exception + + def check_is_block(self, raise_exception=True): + if hasattr(self.request, 'data'): + username = self.request.data.get("username") + else: + username = self.request.POST.get("username") + self._check_is_block(username, raise_exception) + def _check_only_allow_exists_user_auth(self, username): # 仅允许预先存在的用户认证 if not settings.ONLY_ALLOW_EXIST_USER_AUTH: @@ -224,105 +236,92 @@ class AuthMixin(PasswordEncryptionViewMixin): logger.error(f"Only allow exist user auth, login failed: {username}") self.raise_credential_error(errors.reason_user_not_exist) - def _check_auth_user_is_valid(self, username, password, public_key): - user = authenticate(self.request, username=username, password=password, public_key=public_key) - if not user: - self.raise_credential_error(errors.reason_password_failed) - elif user.is_expired: - self.raise_credential_error(errors.reason_user_expired) - elif not user.is_active: - self.raise_credential_error(errors.reason_user_inactive) - return user - def _check_login_mfa_login_if_need(self, user): +class MFAMixin: + request: Request + get_user_from_session: Callable + get_request_ip: Callable + + def _check_login_page_mfa_if_need(self, user): + if not settings.SECURITY_MFA_IN_LOGIN_PAGE: + return + request = self.request - if hasattr(request, 'data'): - data = request.data - else: - data = request.POST + data = request.data if hasattr(request, 'data') else request.POST code = data.get('code') - mfa_type = data.get('mfa_type') - if settings.SECURITY_MFA_IN_LOGIN_PAGE and mfa_type: - if not code: - if mfa_type == MFAType.OTP and bool(user.otp_secret_key): - raise errors.OTPCodeRequiredError - elif mfa_type == MFAType.SMS_CODE: - raise errors.SMSCodeRequiredError - self.check_user_mfa(code, mfa_type, user=user) + mfa_type = data.get('mfa_type', 'otp') + if not code: + raise errors.MFACodeRequiredError + self._do_check_user_mfa(code, mfa_type, user=user) - def _check_login_acl(self, user, ip): - # ACL 限制用户登录 - is_allowed, limit_type = LoginACL.allow_user_to_login(user, ip) - if not is_allowed: - if limit_type == 'ip': - raise errors.LoginIPNotAllowed(username=user.username, request=self.request) - elif limit_type == 'time': - raise errors.TimePeriodNotAllowed(username=user.username, request=self.request) + def check_user_mfa_if_need(self, user): + if self.request.session.get('auth_mfa'): + return + if not user.mfa_enabled: + return - def set_login_failed_mark(self): + active_mfa_mapper = user.active_mfa_backends_mapper + if not active_mfa_mapper: + url = reverse('authentication:user-otp-enable-start') + raise errors.MFAUnsetError(user, self.request, url) + raise errors.MFARequiredError(mfa_types=tuple(active_mfa_mapper.keys())) + + def mark_mfa_ok(self, mfa_type): + self.request.session['auth_mfa'] = 1 + self.request.session['auth_mfa_time'] = time.time() + self.request.session['auth_mfa_required'] = 0 + self.request.session['auth_mfa_type'] = mfa_type + + def clean_mfa_mark(self): + keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type'] + for k in keys: + self.request.session.pop(k, '') + + def check_mfa_is_block(self, username, ip, raise_exception=True): + blocked = MFABlockUtils(username, ip).is_block() + if not blocked: + return + logger.warn('Ip was blocked' + ': ' + username + ':' + ip) + exception = errors.BlockMFAError(username=username, request=self.request, ip=ip) + if raise_exception: + raise exception + else: + return exception + + def _do_check_user_mfa(self, code, mfa_type, user=None): + user = user if user else self.get_user_from_session() + if not user.mfa_enabled: + return + + # 监测 MFA 是不是屏蔽了 ip = self.get_request_ip() - cache.set(self.key_prefix_captcha.format(ip), 1, 3600) + self.check_mfa_is_block(user.username, ip) - def set_passwd_verify_on_session(self, user: User): - self.request.session['user_id'] = str(user.id) - self.request.session['auth_password'] = 1 - self.request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS + ok = False + mfa_backend = user.get_mfa_backend_by_type(mfa_type) + if mfa_backend: + ok, msg = mfa_backend.check_code(code) + else: + msg = _('The MFA type({}) is not supported'.format(mfa_type)) - def check_is_need_captcha(self): - # 最近有登录失败时需要填写验证码 - ip = get_request_ip(self.request) - need = cache.get(self.key_prefix_captcha.format(ip)) - return need + if ok: + self.mark_mfa_ok(mfa_type) + return - def check_user_auth(self, decrypt_passwd=False): - self.check_is_block() - username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd) + raise errors.MFAFailedError( + username=user.username, + request=self.request, + ip=ip, mfa_type=mfa_type, + error=msg + ) - self._check_only_allow_exists_user_auth(username) - user = self._check_auth_user_is_valid(username, password, public_key) - # 校验login-acl规则 - self._check_login_acl(user, ip) - self._check_password_require_reset_or_not(user) - self._check_passwd_is_too_simple(user, password) - self._check_passwd_need_update(user) + @staticmethod + def get_user_mfa_context(user=None): + mfa_backends = User.get_user_mfa_backends(user) + return {'mfa_backends': mfa_backends} - # 校验login-mfa, 如果登录页面上显示 mfa 的话 - self._check_login_mfa_login_if_need(user) - - LoginBlockUtil(username, ip).clean_failed_count() - request = self.request - request.session['auth_password'] = 1 - request.session['user_id'] = str(user.id) - request.session['auto_login'] = auto_login - request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL) - return user - - def _check_is_local_user(self, user: User): - if user.source != User.Source.local: - raise self.raise_credential_error(error=errors.only_local_users_are_allowed) - - def check_oauth2_auth(self, user: User, auth_backend): - ip = self.get_request_ip() - request = self.request - - self._set_partial_credential_error(user.username, ip, request) - - if user.is_expired: - self.raise_credential_error(errors.reason_user_expired) - elif not user.is_active: - self.raise_credential_error(errors.reason_user_inactive) - - self._check_is_block(user.username) - self._check_login_acl(user, ip) - - LoginBlockUtil(user.username, ip).clean_failed_count() - MFABlockUtils(user.username, ip).clean_failed_count() - - request.session['auth_password'] = 1 - request.session['user_id'] = str(user.id) - request.session['auth_backend'] = auth_backend - return user +class AuthPostCheckMixin: @classmethod def generate_reset_password_url_with_flash_msg(cls, user, message): reset_passwd_url = reverse('authentication:reset-password') @@ -344,14 +343,14 @@ class AuthMixin(PasswordEncryptionViewMixin): if user.is_superuser and password == 'admin': message = _('Your password is too simple, please change it for security') url = cls.generate_reset_password_url_with_flash_msg(user, message=message) - raise errors.PasswdTooSimple(url) + raise errors.PasswordTooSimple(url) @classmethod def _check_passwd_need_update(cls, user: User): if user.need_update_password: message = _('You should to change your password before login') url = cls.generate_reset_password_url_with_flash_msg(user, message) - raise errors.PasswdNeedUpdate(url) + raise errors.PasswordNeedUpdate(url) @classmethod def _check_password_require_reset_or_not(cls, user: User): @@ -360,76 +359,20 @@ class AuthMixin(PasswordEncryptionViewMixin): url = cls.generate_reset_password_url_with_flash_msg(user, message) raise errors.PasswordRequireResetError(url) - def check_user_auth_if_need(self, decrypt_passwd=False): - request = self.request - if request.session.get('auth_password') and \ - request.session.get('user_id'): - user = self.get_user_from_session() - if user: - return user - return self.check_user_auth(decrypt_passwd=decrypt_passwd) - def check_user_mfa_if_need(self, user): - if self.request.session.get('auth_mfa'): +class AuthACLMixin: + request: Request + get_request_ip: Callable + + def _check_login_acl(self, user, ip): + # ACL 限制用户登录 + is_allowed, limit_type = LoginACL.allow_user_to_login(user, ip) + if is_allowed: return - if settings.OTP_IN_RADIUS: - return - if not user.mfa_enabled: - return - - unset, url = user.mfa_enabled_but_not_set() - if unset: - raise errors.MFAUnsetError(user, self.request, url) - raise errors.MFARequiredError(mfa_types=user.get_supported_mfa_types()) - - def mark_mfa_ok(self, mfa_type=MFAType.OTP): - self.request.session['auth_mfa'] = 1 - self.request.session['auth_mfa_time'] = time.time() - self.request.session['auth_mfa_required'] = '' - self.request.session['auth_mfa_type'] = mfa_type - - def clean_mfa_mark(self): - self.request.session['auth_mfa'] = '' - self.request.session['auth_mfa_time'] = '' - self.request.session['auth_mfa_required'] = '' - self.request.session['auth_mfa_type'] = '' - - def check_mfa_is_block(self, username, ip, raise_exception=True): - blocked = MFABlockUtils(username, ip).is_block() - if not blocked: - return - logger.warn('Ip was blocked' + ': ' + username + ':' + ip) - exception = errors.BlockMFAError(username=username, request=self.request, ip=ip) - if raise_exception: - raise exception - else: - return exception - - def check_user_mfa(self, code, mfa_type=MFAType.OTP, user=None): - user = user if user else self.get_user_from_session() - if not user.mfa_enabled: - return - - if not bool(user.phone) and mfa_type == MFAType.SMS_CODE: - raise errors.UserPhoneNotSet - - if not bool(user.otp_secret_key) and mfa_type == MFAType.OTP: - self.set_passwd_verify_on_session(user) - raise errors.OTPBindRequiredError(reverse_lazy('authentication:user-otp-enable-bind')) - - ip = self.get_request_ip() - self.check_mfa_is_block(user.username, ip) - ok = user.check_mfa(code, mfa_type=mfa_type) - - if ok: - self.mark_mfa_ok() - return - - raise errors.MFAFailedError( - username=user.username, - request=self.request, - ip=ip, mfa_type=mfa_type, - ) + if limit_type == 'ip': + raise errors.LoginIPNotAllowed(username=user.username, request=self.request) + elif limit_type == 'time': + raise errors.TimePeriodNotAllowed(username=user.username, request=self.request) def get_ticket(self): from tickets.models import Ticket @@ -480,11 +423,99 @@ class AuthMixin(PasswordEncryptionViewMixin): self.get_ticket_or_create(confirm_setting) self.check_user_login_confirm() + +class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPostCheckMixin): + request = None + partial_credential_error = None + + key_prefix_captcha = "_LOGIN_INVALID_{}" + + def _check_auth_user_is_valid(self, username, password, public_key): + user = authenticate( + self.request, username=username, + password=password, public_key=public_key + ) + if not user: + self.raise_credential_error(errors.reason_password_failed) + elif user.is_expired: + self.raise_credential_error(errors.reason_user_expired) + elif not user.is_active: + self.raise_credential_error(errors.reason_user_inactive) + return user + + def set_login_failed_mark(self): + ip = self.get_request_ip() + cache.set(self.key_prefix_captcha.format(ip), 1, 3600) + + def check_is_need_captcha(self): + # 最近有登录失败时需要填写验证码 + ip = get_request_ip(self.request) + need = cache.get(self.key_prefix_captcha.format(ip)) + return need + + def check_user_auth(self, decrypt_passwd=False): + # pre check + self.check_is_block() + username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd) + self._check_only_allow_exists_user_auth(username) + + # check auth + user = self._check_auth_user_is_valid(username, password, public_key) + + # 校验login-acl规则 + self._check_login_acl(user, ip) + + # post check + self._check_password_require_reset_or_not(user) + self._check_passwd_is_too_simple(user, password) + self._check_passwd_need_update(user) + + # 校验login-mfa, 如果登录页面上显示 mfa 的话 + self._check_login_page_mfa_if_need(user) + + # 标记密码验证成功 + self.mark_password_ok(user=user, auto_login=auto_login) + LoginBlockUtil(user.username, ip).clean_failed_count() + return user + + def mark_password_ok(self, user, auto_login=False): + request = self.request + request.session['auth_password'] = 1 + request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS + request.session['user_id'] = str(user.id) + request.session['auto_login'] = auto_login + request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL) + + def check_oauth2_auth(self, user: User, auth_backend): + ip = self.get_request_ip() + request = self.request + + self._set_partial_credential_error(user.username, ip, request) + + if user.is_expired: + self.raise_credential_error(errors.reason_user_expired) + elif not user.is_active: + self.raise_credential_error(errors.reason_user_inactive) + + self._check_is_block(user.username) + self._check_login_acl(user, ip) + + LoginBlockUtil(user.username, ip).clean_failed_count() + MFABlockUtils(user.username, ip).clean_failed_count() + + self.mark_password_ok(user, False) + return user + + def check_user_auth_if_need(self, decrypt_passwd=False): + request = self.request + if not request.session.get('auth_password'): + return self.check_user_auth(decrypt_passwd=decrypt_passwd) + return self.get_user_from_session() + def clear_auth_mark(self): - self.request.session['auth_password'] = '' - self.request.session['auth_user_id'] = '' - self.request.session['auth_confirm'] = '' - self.request.session['auth_ticket_id'] = '' + keys = ['auth_password', 'user_id', 'auth_confirm', 'auth_ticket_id'] + for k in keys: + self.request.session.pop(k, '') def send_auth_signal(self, success=True, user=None, username='', reason=''): if success: @@ -503,31 +534,3 @@ class AuthMixin(PasswordEncryptionViewMixin): if args: guard_url = "%s?%s" % (guard_url, args) return redirect(guard_url) - - @staticmethod - def get_user_mfa_methods(user=None): - otp_enabled = user.otp_secret_key if user else True - # 没有用户时,或者有用户并且有电话配置 - sms_enabled = any([user and user.phone, not user]) \ - and settings.SMS_ENABLED and settings.XPACK_ENABLED - - methods = [ - { - 'name': 'otp', - 'label': 'MFA', - 'enable': otp_enabled, - 'selected': False, - }, - { - 'name': 'sms', - 'label': _('SMS'), - 'enable': sms_enabled, - 'selected': False, - }, - ] - - for item in methods: - if item['enable']: - item['selected'] = True - break - return methods diff --git a/apps/authentication/serializers.py b/apps/authentication/serializers.py index 548819089..a87e1e942 100644 --- a/apps/authentication/serializers.py +++ b/apps/authentication/serializers.py @@ -78,6 +78,7 @@ class BearerTokenSerializer(serializers.Serializer): class MFASelectTypeSerializer(serializers.Serializer): type = serializers.CharField() + username = serializers.CharField(required=False, allow_blank=True, allow_null=True) class MFAChallengeSerializer(serializers.Serializer): diff --git a/apps/authentication/signals_handlers.py b/apps/authentication/signals_handlers.py index 87b177ff5..d895c8498 100644 --- a/apps/authentication/signals_handlers.py +++ b/apps/authentication/signals_handlers.py @@ -13,11 +13,11 @@ from .signals import post_auth_success, post_auth_failed @receiver(user_logged_in) def on_user_auth_login_success(sender, user, request, **kwargs): - # 开启了 MFA,且没有校验过 - - if user.mfa_enabled and not settings.OTP_IN_RADIUS and not request.session.get('auth_mfa'): + # 开启了 MFA,且没有校验过, 可以全局校验, middleware 中可以全局管理 oidc 等第三方认证的 MFA + if user.mfa_enabled and not request.session.get('auth_mfa'): request.session['auth_mfa_required'] = 1 + # 单点登录,超过了自动退出 if settings.USER_LOGIN_SINGLE_MACHINE_ENABLED: user_id = 'single_machine_login_' + str(user.id) session_key = cache.get(user_id) diff --git a/apps/authentication/templates/authentication/login.html b/apps/authentication/templates/authentication/login.html index f968e325c..a6550e725 100644 --- a/apps/authentication/templates/authentication/login.html +++ b/apps/authentication/templates/authentication/login.html @@ -160,7 +160,7 @@ {% bootstrap_field form.challenge show_label=False %} {% elif form.mfa_type %}