mirror of https://github.com/jumpserver/jumpserver
perf: Oauth2.0 support two methods for passing authentication credentials.
parent
d4dc31aefa
commit
35a1655905
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
#
|
#
|
||||||
|
import base64
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
@ -67,14 +68,6 @@ class OAuth2Backend(JMSModelBackend):
|
||||||
response_data = response_data['data']
|
response_data = response_data['data']
|
||||||
return response_data
|
return response_data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_query_dict(response_data, query_dict):
|
|
||||||
query_dict.update({
|
|
||||||
'uid': response_data.get('uid', ''),
|
|
||||||
'access_token': response_data.get('access_token', '')
|
|
||||||
})
|
|
||||||
return query_dict
|
|
||||||
|
|
||||||
def authenticate(self, request, code=None, **kwargs):
|
def authenticate(self, request, code=None, **kwargs):
|
||||||
log_prompt = "Process authenticate [OAuth2Backend]: {}"
|
log_prompt = "Process authenticate [OAuth2Backend]: {}"
|
||||||
logger.debug(log_prompt.format('Start'))
|
logger.debug(log_prompt.format('Start'))
|
||||||
|
@ -83,29 +76,31 @@ class OAuth2Backend(JMSModelBackend):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
query_dict = {
|
query_dict = {
|
||||||
'client_id': settings.AUTH_OAUTH2_CLIENT_ID,
|
'grant_type': 'authorization_code', 'code': code,
|
||||||
'client_secret': settings.AUTH_OAUTH2_CLIENT_SECRET,
|
|
||||||
'grant_type': 'authorization_code',
|
|
||||||
'code': code,
|
|
||||||
'redirect_uri': build_absolute_uri(
|
'redirect_uri': build_absolute_uri(
|
||||||
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
|
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if '?' in settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT:
|
separator = '&' if '?' in settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT else '?'
|
||||||
separator = '&'
|
|
||||||
else:
|
|
||||||
separator = '?'
|
|
||||||
access_token_url = '{url}{separator}{query}'.format(
|
access_token_url = '{url}{separator}{query}'.format(
|
||||||
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
|
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT,
|
||||||
|
separator=separator, query=urlencode(query_dict)
|
||||||
)
|
)
|
||||||
# token_method -> get, post(post_data), post_json
|
# token_method -> get, post(post_data), post_json
|
||||||
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
|
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
|
||||||
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
|
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
|
||||||
|
encoded_credentials = base64.b64encode(
|
||||||
|
f"{settings.AUTH_OAUTH2_CLIENT_ID}:{settings.AUTH_OAUTH2_CLIENT_SECRET}".encode()
|
||||||
|
).decode()
|
||||||
headers = {
|
headers = {
|
||||||
'Accept': 'application/json'
|
'Accept': 'application/json', 'Authorization': f'Basic {encoded_credentials}'
|
||||||
}
|
}
|
||||||
if token_method.startswith('post'):
|
if token_method.startswith('post'):
|
||||||
body_key = 'json' if token_method.endswith('json') else 'data'
|
body_key = 'json' if token_method.endswith('json') else 'data'
|
||||||
|
query_dict.update({
|
||||||
|
'client_id': settings.AUTH_OAUTH2_CLIENT_ID,
|
||||||
|
'client_secret': settings.AUTH_OAUTH2_CLIENT_SECRET,
|
||||||
|
})
|
||||||
access_token_response = requests.post(
|
access_token_response = requests.post(
|
||||||
access_token_url, headers=headers, **{body_key: query_dict}
|
access_token_url, headers=headers, **{body_key: query_dict}
|
||||||
)
|
)
|
||||||
|
@ -121,22 +116,12 @@ class OAuth2Backend(JMSModelBackend):
|
||||||
logger.error(log_prompt.format(error))
|
logger.error(log_prompt.format(error))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
query_dict = self.get_query_dict(response_data, query_dict)
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Accept': 'application/json',
|
'Accept': 'application/json',
|
||||||
'Authorization': 'Bearer {}'.format(response_data.get('access_token', ''))
|
'Authorization': 'Bearer {}'.format(response_data.get('access_token', ''))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(log_prompt.format('Get userinfo endpoint'))
|
logger.debug(log_prompt.format('Get userinfo endpoint'))
|
||||||
if '?' in settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT:
|
userinfo_url = settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT
|
||||||
separator = '&'
|
|
||||||
else:
|
|
||||||
separator = '?'
|
|
||||||
userinfo_url = '{url}{separator}{query}'.format(
|
|
||||||
url=settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT, separator=separator,
|
|
||||||
query=urlencode(query_dict)
|
|
||||||
)
|
|
||||||
userinfo_response = requests.get(userinfo_url, headers=headers)
|
userinfo_response = requests.get(userinfo_url, headers=headers)
|
||||||
try:
|
try:
|
||||||
userinfo_response.raise_for_status()
|
userinfo_response.raise_for_status()
|
||||||
|
|
Loading…
Reference in New Issue