feat: 优化SAML2生成的metadata文件内容及属性映射

pull/8160/head
jiangweidong 2022-04-21 15:18:17 +08:00 committed by 老广
parent 9804ca5dd0
commit 3a3f7eaf71
1 changed files with 33 additions and 20 deletions

View File

@ -74,27 +74,37 @@ class PrepareRequestMixin:
return idp_settings return idp_settings
@staticmethod @staticmethod
def get_attribute_consuming_service(): def get_request_attributes():
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES attr_mapping = settings.SAML2_RENAME_ATTRIBUTES or {}
if attr_mapping and isinstance(attr_mapping, dict): attr_map_reverse = {v: k for k, v in attr_mapping.items()}
attr_list = [ need_attrs = (
{ ('username', 'username', True),
"name": sp_key, ('email', 'email', True),
"friendlyName": idp_key, "isRequired": True ('name', 'name', False),
} ('phone', 'phone', False),
for idp_key, sp_key in attr_mapping.items() ('comment', 'comment', False),
] )
request_attribute_template = { attr_list = []
"attributeConsumingService": { for name, friend_name, is_required in need_attrs:
"isDefault": False, rename_name = attr_map_reverse.get(friend_name)
"serviceName": "JumpServer", name = rename_name if rename_name else name
"serviceDescription": "JumpServer", attr_list.append({
"requestedAttributes": attr_list "name": name, "isRequired": is_required,
} "friendlyName": friend_name,
})
return attr_list
def get_attribute_consuming_service(self):
attr_list = self.get_request_attributes()
request_attribute_template = {
"attributeConsumingService": {
"isDefault": False,
"serviceName": "JumpServer",
"serviceDescription": "JumpServer",
"requestedAttributes": attr_list
} }
return request_attribute_template }
else: return request_attribute_template
return {}
@staticmethod @staticmethod
def get_advanced_settings(): def get_advanced_settings():
@ -167,11 +177,14 @@ class PrepareRequestMixin:
def get_attributes(self, saml_instance): def get_attributes(self, saml_instance):
user_attrs = {} user_attrs = {}
attr_mapping = settings.SAML2_RENAME_ATTRIBUTES
attrs = saml_instance.get_attributes() attrs = saml_instance.get_attributes()
valid_attrs = ['username', 'name', 'email', 'comment', 'phone'] valid_attrs = ['username', 'name', 'email', 'comment', 'phone']
for attr, value in attrs.items(): for attr, value in attrs.items():
attr = attr.rsplit('/', 1)[-1] attr = attr.rsplit('/', 1)[-1]
if attr_mapping and attr_mapping.get(attr):
attr = attr_mapping.get(attr)
if attr not in valid_attrs: if attr not in valid_attrs:
continue continue
user_attrs[attr] = self.value_to_str(value) user_attrs[attr] = self.value_to_str(value)