diff --git a/apps/jumpserver/settings.py b/apps/jumpserver/settings.py index 9ab8c5242..b82f3be49 100644 --- a/apps/jumpserver/settings.py +++ b/apps/jumpserver/settings.py @@ -278,7 +278,7 @@ REST_FRAMEWORK = { 'users.authentication.AccessKeyAuthentication', 'users.authentication.AccessTokenAuthentication', 'users.authentication.PrivateTokenAuthentication', - 'rest_framework.authentication.SessionAuthentication', + 'users.authentication.SessionAuthentication', ), 'DEFAULT_FILTER_BACKENDS': ('django_filters.rest_framework.DjangoFilterBackend',), } diff --git a/apps/perms/api.py b/apps/perms/api.py index b80e351ef..7b50c4a9d 100644 --- a/apps/perms/api.py +++ b/apps/perms/api.py @@ -2,6 +2,7 @@ # from rest_framework.views import APIView, Response +from rest_framework.decorators import api_view from rest_framework.generics import ListAPIView, get_object_or_404 from rest_framework import viewsets from users.permissions import IsValidUser, IsSuperUser @@ -127,7 +128,7 @@ class MyGrantedAssetsGroupsApi(APIView): for asset in assets: for asset_group in asset.groups.all(): if asset_group.id in asset_groups: - asset_groups[asset_group.id]['asset_amount'] += 1 + asset_groups[asset_group.id]['assets_amount'] += 1 else: asset_groups[asset_group.id] = { 'id': asset_group.id, diff --git a/apps/users/api.py b/apps/users/api.py index a31bd0207..8772adbdf 100644 --- a/apps/users/api.py +++ b/apps/users/api.py @@ -8,6 +8,9 @@ from django.conf import settings from rest_framework import generics, viewsets from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.decorators import api_view +from rest_framework.permissions import AllowAny +from rest_framework.authentication import SessionAuthentication from rest_framework_bulk import BulkModelViewSet from django_filters.rest_framework import DjangoFilterBackend @@ -86,13 +89,21 @@ class UserGroupUpdateUserApi(generics.RetrieveUpdateAPIView): class UserToken(APIView): - permission_classes = (IsValidUser,) + permission_classes = (AllowAny,) - def get(self, request): - if not request.user: - return Response({'error': 'unauthorized'}) - token = generate_token(request) - return Response({'token': token}) + def post(self, request): + username = request.data.get('username', '') + email = request.data.get('email', '') + password = request.data.get('password', '') + public_key = request.data.get('public_key', '') + + user, msg = check_user_valid(username=username, email=email, + password=password, public_key=public_key) + if user: + token = generate_token(request) + return Response({'Token': token, 'key': 'Bearer'}, status=200) + else: + return Response({'error': msg}, status=406) class UserProfile(APIView): diff --git a/apps/users/authentication.py b/apps/users/authentication.py index 63493dec5..6647bfcfd 100644 --- a/apps/users/authentication.py +++ b/apps/users/authentication.py @@ -122,3 +122,8 @@ class AccessTokenAuthentication(authentication.BaseAuthentication): class PrivateTokenAuthentication(authentication.TokenAuthentication): model = PrivateToken + + +class SessionAuthentication(authentication.SessionAuthentication): + def enforce_csrf(self, request): + return None \ No newline at end of file diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 985d6f4bd..1d71fd31e 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -45,8 +45,6 @@ class User(AbstractUser): verbose_name=_('Date expired')) created_by = models.CharField(max_length=30, default='', verbose_name=_('Created by')) - - @property def password_raw(self): raise AttributeError('Password raw is not a readable attribute') diff --git a/apps/users/urls/api_urls.py b/apps/users/urls/api_urls.py index 3964a5030..34401d525 100644 --- a/apps/users/urls/api_urls.py +++ b/apps/users/urls/api_urls.py @@ -12,18 +12,17 @@ app_name = 'users' router = BulkRouter() router.register(r'v1/users', api.UserViewSet, 'user') router.register(r'v1/user-groups', api.UserGroupViewSet, 'user-group') -# router.register(r'v1/user-groups', api.AssetViewSet, 'api-groups') urlpatterns = [ - url(r'^v1/token$', api.UserToken.as_view(), name='user-token'), - url(r'^v1/profile$', api.UserProfile.as_view(), name='user-profile'), - url(r'^v1/users/(?P\d+)/reset-password$', api.UserResetPasswordApi.as_view(), name='user-reset-password'), - url(r'^v1/users/(?P\d+)/reset-pk$', api.UserResetPKApi.as_view(), name='user-reset-pk'), - url(r'^v1/users/(?P\d+)/update-pk$', api.UserUpdatePKApi.as_view(), name='user-update-pk'), - url(r'^v1/users/(?P\d+)/groups$', + url(r'^v1/token/$', api.UserToken.as_view(), name='user-token'), + url(r'^v1/profile/$', api.UserProfile.as_view(), name='user-profile'), + url(r'^v1/users/(?P\d+)/password/reset/$', api.UserResetPasswordApi.as_view(), name='user-reset-password'), + url(r'^v1/users/(?P\d+)/public-key/reset/$', api.UserResetPKApi.as_view(), name='user-public-key-reset'), + url(r'^v1/users/(?P\d+)/public-key/update/$', api.UserUpdatePKApi.as_view(), name='user-public-key-update'), + url(r'^v1/users/(?P\d+)/groups/$', api.UserUpdateGroupApi.as_view(), name='user-update-group'), - url(r'^v1/user-groups/(?P\d+)/users$', + url(r'^v1/user-groups/(?P\d+)/users/$', api.UserGroupUpdateUserApi.as_view(), name='user-group-update-user'), ] diff --git a/apps/users/utils.py b/apps/users/utils.py index bf0127600..418cc60a6 100644 --- a/apps/users/utils.py +++ b/apps/users/utils.py @@ -180,21 +180,33 @@ def send_reset_ssh_key_mail(user): def check_user_valid(**kwargs): password = kwargs.pop('password', None) public_key = kwargs.pop('public_key', None) - user = get_object_or_none(User, **kwargs) + email = kwargs.pop('email') + username = kwargs.pop('username') + + if username: + user = get_object_or_none(User, username=username) + elif email: + user = get_object_or_none(User, email=email) + else: + user = None + + if user is None: + return None, _('User not exist') + elif not user.is_valid: + return None, _('Disabled or expired') - if user is None or not user.is_valid: - return None if password and user.check_password(password): - return user + return user, '' + if public_key: public_key_saved = user.public_key.split() if len(public_key_saved) == 1: if public_key == public_key_saved[0]: - return user + return user, '' elif len(public_key_saved) > 1: if public_key == public_key_saved[1]: - return user - return None + return user, '' + return None, _('Passowrd or SSH public key invalid') def refresh_token(token, user):