diff --git a/apps/authentication/backends/saml2/views.py b/apps/authentication/backends/saml2/views.py index b0b8fef8d..e91fd0660 100644 --- a/apps/authentication/backends/saml2/views.py +++ b/apps/authentication/backends/saml2/views.py @@ -74,27 +74,37 @@ class PrepareRequestMixin: return idp_settings @staticmethod - def get_attribute_consuming_service(): - attr_mapping = settings.SAML2_RENAME_ATTRIBUTES - if attr_mapping and isinstance(attr_mapping, dict): - attr_list = [ - { - "name": sp_key, - "friendlyName": idp_key, "isRequired": True - } - for idp_key, sp_key in attr_mapping.items() - ] - request_attribute_template = { - "attributeConsumingService": { - "isDefault": False, - "serviceName": "JumpServer", - "serviceDescription": "JumpServer", - "requestedAttributes": attr_list - } + def get_request_attributes(): + attr_mapping = settings.SAML2_RENAME_ATTRIBUTES or {} + attr_map_reverse = {v: k for k, v in attr_mapping.items()} + need_attrs = ( + ('username', 'username', True), + ('email', 'email', True), + ('name', 'name', False), + ('phone', 'phone', False), + ('comment', 'comment', False), + ) + attr_list = [] + for name, friend_name, is_required in need_attrs: + rename_name = attr_map_reverse.get(friend_name) + name = rename_name if rename_name else name + attr_list.append({ + "name": name, "isRequired": is_required, + "friendlyName": friend_name, + }) + return attr_list + + def get_attribute_consuming_service(self): + attr_list = self.get_request_attributes() + request_attribute_template = { + "attributeConsumingService": { + "isDefault": False, + "serviceName": "JumpServer", + "serviceDescription": "JumpServer", + "requestedAttributes": attr_list } - return request_attribute_template - else: - return {} + } + return request_attribute_template @staticmethod def get_advanced_settings(): @@ -167,11 +177,14 @@ class PrepareRequestMixin: def get_attributes(self, saml_instance): user_attrs = {} + attr_mapping = settings.SAML2_RENAME_ATTRIBUTES attrs = saml_instance.get_attributes() valid_attrs = ['username', 'name', 'email', 'comment', 'phone'] for attr, value in attrs.items(): attr = attr.rsplit('/', 1)[-1] + if attr_mapping and attr_mapping.get(attr): + attr = attr_mapping.get(attr) if attr not in valid_attrs: continue user_attrs[attr] = self.value_to_str(value)