perf: Support SAML2, OIDC user authentication services, mapping user group field information

pull/14091/head
feng 2024-08-29 19:23:04 +08:00 committed by feng626
parent 1068662ab1
commit c545e2a3aa
4 changed files with 67 additions and 17 deletions

View File

@ -8,27 +8,26 @@
"""
import base64
import requests
from rest_framework.exceptions import ParseError
import requests
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import SuspiciousOperation
from django.db import transaction
from django.urls import reverse
from django.conf import settings
from rest_framework.exceptions import ParseError
from common.utils import get_logger
from authentication.signals import user_auth_success, user_auth_failed
from authentication.utils import build_absolute_uri_for_oidc
from common.utils import get_logger
from users.utils import construct_user_email
from ..base import JMSBaseAuthBackend
from .utils import validate_and_return_id_token
from .decorator import ssl_verification
from .signals import (
openid_create_or_update_user
)
from authentication.signals import user_auth_success, user_auth_failed
from .utils import validate_and_return_id_token
from ..base import JMSBaseAuthBackend
logger = get_logger(__file__)
@ -55,16 +54,17 @@ class UserMixin:
logger.debug(log_prompt.format(user_attrs))
username = user_attrs.get('username')
name = user_attrs.get('name')
groups = user_attrs.pop('groups', None)
user, created = get_user_model().objects.get_or_create(
username=username, defaults=user_attrs
)
user_attrs['groups'] = groups
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
logger.debug(log_prompt.format("Send signal => openid create or update user"))
openid_create_or_update_user.send(
sender=self.__class__, request=request, user=user, created=created,
name=name, username=username, email=email
sender=self.__class__, request=request, user=user,
created=created, attrs=user_attrs,
)
return user, created
@ -269,7 +269,8 @@ class OIDCAuthPasswordBackend(OIDCBaseBackend):
# Calls the token endpoint.
logger.debug(log_prompt.format('Call the token endpoint'))
token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload, timeout=request_timeout)
token_response = requests.post(settings.AUTH_OPENID_PROVIDER_TOKEN_ENDPOINT, data=token_payload,
timeout=request_timeout)
try:
token_response.raise_for_status()
token_response_data = token_response.json()

View File

@ -27,9 +27,13 @@ class SAML2Backend(JMSModelBackend):
log_prompt = "Get or Create user [SAML2Backend]: {}"
logger.debug(log_prompt.format('start'))
groups = saml_user_data.pop('groups', None)
user, created = get_user_model().objects.get_or_create(
username=saml_user_data['username'], defaults=saml_user_data
)
saml_user_data['groups'] = groups
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
logger.debug(log_prompt.format("Send signal => saml2 create or update user"))

View File

@ -87,6 +87,7 @@ class PrepareRequestMixin:
('name', 'name', False),
('phone', 'phone', False),
('comment', 'comment', False),
('groups', 'groups', False),
)
attr_list = []
for name, friend_name, is_required in need_attrs:
@ -185,7 +186,7 @@ class PrepareRequestMixin:
user_attrs = {}
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES
attrs = saml_instance.get_attributes()
valid_attrs = ['username', 'name', 'email', 'comment', 'phone']
valid_attrs = ['username', 'name', 'email', 'comment', 'phone', 'groups']
for attr, value in attrs.items():
attr = attr.rsplit('/', 1)[-1]

View File

@ -21,11 +21,13 @@ from common.signals import django_ready
from common.utils import get_logger
from jumpserver.utils import get_current_request
from ops.celery.decorator import register_as_period_task
from orgs.models import Organization
from orgs.utils import tmp_to_root_org
from rbac.builtin import BuiltinRole
from rbac.const import Scope
from rbac.models import RoleBinding
from settings.signals import setting_changed
from .models import User, UserPasswordHistory
from .models import User, UserPasswordHistory, UserGroup
from .signals import post_user_create
logger = get_logger(__file__)
@ -50,7 +52,9 @@ def user_authenticated_handle(user, created, source, attrs=None, **kwargs):
if created:
user.source = source
user.save()
bind_user_to_org_role(user)
org_ids = bind_user_to_org_role(user)
group_names = attrs.get('groups')
bind_user_to_group(org_ids, group_names, user)
if not attrs:
return
@ -146,7 +150,7 @@ def radius_create_user(sender, user, **kwargs):
@receiver(openid_create_or_update_user)
def on_openid_create_or_update_user(sender, request, user, created, name, username, email, **kwargs):
def on_openid_create_or_update_user(sender, request, user, created, attrs, **kwargs):
if not check_only_allow_exist_user_auth(created):
return
@ -157,7 +161,13 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
)
user.source = User.Source.openid.value
user.save()
bind_user_to_org_role(user)
org_ids = bind_user_to_org_role(user)
group_names = attrs.get('groups')
bind_user_to_group(org_ids, group_names, user)
name = attrs.get('name')
username = attrs.get('username')
email = attrs.get('email')
if not created and settings.AUTH_OPENID_ALWAYS_UPDATE_USER:
logger.debug(
@ -225,3 +235,37 @@ def bind_user_to_org_role(user):
]
RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)
return org_ids
def bind_user_to_group(org_ids, group_names, user):
if not isinstance(group_names, list):
return
org_ids = org_ids or [Organization.DEFAULT_ID]
with tmp_to_root_org():
existing_groups = UserGroup.objects.filter(org_id__in=org_ids).values_list('org_id', 'name')
org_groups_map = {}
for org_id, group_name in existing_groups:
org_groups_map.setdefault(org_id, []).append(group_name)
groups_to_create = []
for org_id in org_ids:
existing_group_names = set(org_groups_map.get(org_id, []))
new_group_names = set(group_names) - existing_group_names
groups_to_create.extend(
UserGroup(org_id=org_id, name=name) for name in new_group_names
)
UserGroup.objects.bulk_create(groups_to_create)
user_groups = UserGroup.objects.filter(org_id__in=org_ids, name__in=group_names)
user_group_links = [
User.groups.through(user_id=user.id, usergroup_id=group.id)
for group in user_groups
]
if user_group_links:
User.groups.through.objects.bulk_create(user_group_links)