Merge branch 'v3' of github.com:jumpserver/jumpserver into v3

pull/8997/head
ibuler 2022-10-27 18:34:34 +08:00
commit 097ebc2362
15 changed files with 322 additions and 486 deletions

View File

@ -180,9 +180,10 @@ class CommandFilterRule(OrgModelMixin):
@classmethod @classmethod
def get_queryset( def get_queryset(
cls, user_id=None, user_group_id=None, system_user_id=None, cls, user_id=None, user_group_id=None, account=None,
asset_id=None, org_id=None asset_id=None, org_id=None
): ):
from perms.models.const import SpecialAccount
user_groups = [] user_groups = []
user = get_object_or_none(User, pk=user_id) user = get_object_or_none(User, pk=user_id)
if user: if user:
@ -191,7 +192,7 @@ class CommandFilterRule(OrgModelMixin):
if user_group: if user_group:
org_id = user_group.org_id org_id = user_group.org_id
user_groups.append(user_group) user_groups.append(user_group)
account = get_object_or_none(Account, pk=system_user_id)
asset = get_object_or_none(Asset, pk=asset_id) asset = get_object_or_none(Asset, pk=asset_id)
q = Q() q = Q()
if user: if user:
@ -200,7 +201,8 @@ class CommandFilterRule(OrgModelMixin):
q |= Q(user_groups__in=set(user_groups)) q |= Q(user_groups__in=set(user_groups))
if account: if account:
org_id = account.org_id org_id = account.org_id
q |= Q(accounts=account) q |= Q(accounts__contains=list(account)) |\
Q(accounts__contains=SpecialAccount.ALL.value)
if asset: if asset:
org_id = asset.org_id org_id = asset.org_id
q |= Q(assets=asset) q |= Q(assets=asset)

View File

@ -1,6 +1,7 @@
import abc import abc
import os import os
import json import json
import time
import base64 import base64
import urllib.parse import urllib.parse
from django.http import HttpResponse from django.http import HttpResponse
@ -41,17 +42,25 @@ class ConnectionTokenMixin:
def get_request_resources(self, serializer): def get_request_resources(self, serializer):
user = self.get_request_resource_user(serializer) user = self.get_request_resource_user(serializer)
asset = serializer.validated_data.get('asset') asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application') account = serializer.validated_data.get('account')
system_user = serializer.validated_data.get('system_user') return user, asset, account
return user, asset, application, system_user
@staticmethod @staticmethod
def check_user_has_resource_permission(user, asset, application, system_user): def check_user_has_resource_permission(user, asset, account):
from perms.utils.asset import has_asset_system_permission from perms.utils.account import PermAccountUtil
if not asset or not user:
error = ''
raise PermissionDenied(error)
if asset and not has_asset_system_permission(user, asset, system_user): actions, expire_at = PermAccountUtil().validate_permission(
error = f'User not has this asset and system user permission: ' \ user, asset, account_username=account
f'user={user.id} system_user={system_user.id} asset={asset.id}' )
if not actions:
error = ''
raise PermissionDenied(error)
if expire_at < time.time():
error = ''
raise PermissionDenied(error) raise PermissionDenied(error)
def get_smart_endpoint(self, protocol, asset=None, application=None): def get_smart_endpoint(self, protocol, asset=None, application=None):
@ -69,13 +78,12 @@ class ConnectionTokenMixin:
return true_value if is_true(os.getenv(env_key, env_default)) else false_value return true_value if is_true(os.getenv(env_key, env_default)) else false_value
def get_client_protocol_data(self, token: ConnectionToken): def get_client_protocol_data(self, token: ConnectionToken):
from assets.models import SystemUser protocol = token.protocol
protocol = token.system_user.protocol
username = token.user.username username = token.user.username
rdp_config = ssh_token = '' rdp_config = ssh_token = ''
if protocol == SystemUser.Protocol.rdp: if protocol == 'rdp':
filename, rdp_config = self.get_rdp_file_info(token) filename, rdp_config = self.get_rdp_file_info(token)
elif protocol == SystemUser.Protocol.ssh: elif protocol == 'ssh':
filename, ssh_token = self.get_ssh_token(token) filename, ssh_token = self.get_ssh_token(token)
else: else:
raise ValueError('Protocol not support: {}'.format(protocol)) raise ValueError('Protocol not support: {}'.format(protocol))
@ -134,15 +142,12 @@ class ConnectionTokenMixin:
rdp_options['screen mode id:i'] = '2' if full_screen else '1' rdp_options['screen mode id:i'] = '2' if full_screen else '1'
# 设置 RDP Server 地址 # 设置 RDP Server 地址
endpoint = self.get_smart_endpoint( endpoint = self.get_smart_endpoint(protocol='rdp', asset=token.asset)
protocol='rdp', asset=token.asset, application=token.application
)
rdp_options['full address:s'] = f'{endpoint.host}:{endpoint.rdp_port}' rdp_options['full address:s'] = f'{endpoint.host}:{endpoint.rdp_port}'
# 设置用户名 # 设置用户名
rdp_options['username:s'] = '{}|{}'.format(token.user.username, str(token.id)) rdp_options['username:s'] = '{}|{}'.format(token.user.username, str(token.id))
if token.system_user.ad_domain: # rdp_options['domain:s'] = token.account_ad_domain
rdp_options['domain:s'] = token.system_user.ad_domain
# 设置宽高 # 设置宽高
height = self.request.query_params.get('height') height = self.request.query_params.get('height')
@ -158,13 +163,12 @@ class ConnectionTokenMixin:
if token.asset: if token.asset:
name = token.asset.name name = token.asset.name
elif token.application and token.application.category_remote_app: # remote-app
app = '||jmservisor' # app = '||jmservisor'
name = token.application.name # rdp_options['remoteapplicationmode:i'] = '1'
rdp_options['remoteapplicationmode:i'] = '1' # rdp_options['alternate shell:s'] = app
rdp_options['alternate shell:s'] = app # rdp_options['remoteapplicationprogram:s'] = app
rdp_options['remoteapplicationprogram:s'] = app # rdp_options['remoteapplicationname:s'] = name
rdp_options['remoteapplicationname:s'] = name
else: else:
name = '*' name = '*'
prefix_name = f'{token.user.username}-{name}' prefix_name = f'{token.user.username}-{name}'
@ -188,16 +192,12 @@ class ConnectionTokenMixin:
def get_ssh_token(self, token: ConnectionToken): def get_ssh_token(self, token: ConnectionToken):
if token.asset: if token.asset:
name = token.asset.name name = token.asset.name
elif token.application:
name = token.application.name
else: else:
name = '*' name = '*'
prefix_name = f'{token.user.username}-{name}' prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name) filename = self.get_connect_filename(prefix_name)
endpoint = self.get_smart_endpoint( endpoint = self.get_smart_endpoint(protocol='ssh', asset=token.asset)
protocol='ssh', asset=token.asset, application=token.application
)
data = { data = {
'ip': endpoint.host, 'ip': endpoint.host,
'port': str(endpoint.ssh_port), 'port': str(endpoint.ssh_port),
@ -251,8 +251,8 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
return token return token
def perform_create(self, serializer): def perform_create(self, serializer):
user, asset, application, system_user = self.get_request_resources(serializer) user, asset, account = self.get_request_resources(serializer)
self.check_user_has_resource_permission(user, asset, application, system_user) self.check_user_has_resource_permission(user, asset, account)
return super(ConnectionTokenViewSet, self).perform_create(serializer) return super(ConnectionTokenViewSet, self).perform_create(serializer)
@action(methods=['POST'], detail=False, url_path='secret-info/detail') @action(methods=['POST'], detail=False, url_path='secret-info/detail')
@ -264,7 +264,6 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
token_id = request.data.get('token') or '' token_id = request.data.get('token') or ''
token = get_object_or_404(ConnectionToken, pk=token_id) token = get_object_or_404(ConnectionToken, pk=token_id)
self.check_token_valid(token) self.check_token_valid(token)
token.load_system_user_auth()
serializer = self.get_serializer(instance=token) serializer = self.get_serializer(instance=token)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)

View File

@ -1,248 +0,0 @@
import uuid
from datetime import datetime, timedelta
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from rest_framework.authtoken.models import Token
from orgs.mixins.models import OrgModelMixin
from django.db import models
from common.utils import lazyproperty
from common.utils.timezone import as_current_tz
from common.db.models import BaseCreateUpdateModel, JMSBaseModel
class AccessKey(models.Model):
id = models.UUIDField(verbose_name='AccessKeyID', primary_key=True,
default=uuid.uuid4, editable=False)
secret = models.UUIDField(verbose_name='AccessKeySecret',
default=uuid.uuid4, editable=False)
user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name='User',
on_delete=models.CASCADE, related_name='access_keys')
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
date_created = models.DateTimeField(auto_now_add=True)
def get_id(self):
return str(self.id)
def get_secret(self):
return str(self.secret)
def get_full_value(self):
return '{}:{}'.format(self.id, self.secret)
def __str__(self):
return str(self.id)
class Meta:
verbose_name = _("Access key")
class PrivateToken(Token):
"""Inherit from auth token, otherwise migration is boring"""
class Meta:
verbose_name = _('Private Token')
class SSOToken(BaseCreateUpdateModel):
"""
类似腾讯企业邮的 [单点登录](https://exmail.qq.com/qy_mng_logic/doc#10036)
出于安全考虑这里的 `token` 使用一次随即过期但我们保留每一个生成过的 `token`
"""
authkey = models.UUIDField(primary_key=True, default=uuid.uuid4, verbose_name=_('Token'))
expired = models.BooleanField(default=False, verbose_name=_('Expired'))
user = models.ForeignKey('users.User', on_delete=models.CASCADE, verbose_name=_('User'), db_constraint=False)
class Meta:
verbose_name = _('SSO token')
def date_expired_default():
return timezone.now() + timedelta(seconds=settings.CONNECTION_TOKEN_EXPIRATION)
class ConnectionToken(OrgModelMixin, JMSBaseModel):
secret = models.CharField(max_length=64, default='', verbose_name=_("Secret"))
date_expired = models.DateTimeField(
default=date_expired_default, verbose_name=_("Date expired")
)
user = models.ForeignKey(
'users.User', on_delete=models.SET_NULL, verbose_name=_('User'),
related_name='connection_tokens', null=True, blank=True
)
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
asset = models.ForeignKey(
'assets.Asset', on_delete=models.SET_NULL, verbose_name=_('Asset'),
related_name='connection_tokens', null=True, blank=True
)
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
account = models.CharField(max_length=128, default='', verbose_name=_("Account"))
class Meta:
ordering = ('-date_expired',)
verbose_name = _('Connection token')
permissions = [
('view_connectiontokensecret', _('Can view connection token secret'))
]
@classmethod
def get_default_date_expired(cls):
return date_expired_default()
@property
def is_expired(self):
return self.date_expired < timezone.now()
@property
def expire_time(self):
interval = self.date_expired - timezone.now()
seconds = interval.total_seconds()
if seconds < 0:
seconds = 0
return int(seconds)
def expire(self):
self.date_expired = timezone.now()
self.save()
@property
def is_valid(self):
return not self.is_expired
def is_type(self, tp):
return self.type == tp
def renewal(self):
""" 续期 Token将来支持用户自定义创建 token 后,续期策略要修改 """
self.date_expired = self.get_default_date_expired()
self.save()
actions = expired_at = None # actions 和 expired_at 在 check_valid() 中赋值
def check_valid(self):
from perms.utils.permission import validate_permission as asset_validate_permission
if self.is_expired:
is_valid = False
error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired))
return is_valid, error
if not self.user:
is_valid = False
error = _('User not exists')
return is_valid, error
if not self.user.is_valid:
is_valid = False
error = _('User invalid, disabled or expired')
return is_valid, error
if not self.account:
is_valid = False
error = _('Account not exists')
return is_valid, error
if not self.asset:
is_valid = False
error = _('Asset not exists')
return is_valid, error
if not self.asset.is_active:
is_valid = False
error = _('Asset inactive')
return is_valid, error
has_perm, actions, expired_at = asset_validate_permission(
self.user, self.asset, self.account
)
if not has_perm:
is_valid = False
error = _('User has no permission to access asset or permission expired')
return is_valid, error
self.actions = actions
self.expired_at = expired_at
return True, ''
@lazyproperty
def domain(self):
if self.asset:
return self.asset.domain
if not self.application:
return
if self.application.category_remote_app:
asset = self.application.get_remote_app_asset()
domain = asset.domain if asset else None
else:
domain = self.application.domain
return domain
@lazyproperty
def gateway(self):
from assets.models import Domain
if not self.domain:
return
self.domain: Domain
return self.domain.random_gateway()
@lazyproperty
def remote_app(self):
if not self.application:
return {}
if not self.application.category_remote_app:
return {}
return self.application.get_rdp_remote_app_setting()
@lazyproperty
def asset_or_remote_app_asset(self):
if self.asset:
return self.asset
if self.application and self.application.category_remote_app:
return self.application.get_remote_app_asset()
@lazyproperty
def cmd_filter_rules(self):
from assets.models import CommandFilterRule
kwargs = {
'user_id': self.user.id,
'system_user_id': self.system_user.id,
}
if self.asset:
kwargs['asset_id'] = self.asset.id
elif self.application:
kwargs['application_id'] = self.application_id
rules = CommandFilterRule.get_queryset(**kwargs)
return rules
def load_system_user_auth(self):
if self.asset:
self.system_user.load_asset_more_auth(self.asset.id, self.user.username, self.user.id)
elif self.application:
self.system_user.load_app_more_auth(self.application.id, self.user.username, self.user.id)
class TempToken(JMSBaseModel):
username = models.CharField(max_length=128, verbose_name=_("Username"))
secret = models.CharField(max_length=64, verbose_name=_("Secret"))
verified = models.BooleanField(default=False, verbose_name=_("Verified"))
date_verified = models.DateTimeField(null=True, verbose_name=_("Date verified"))
date_expired = models.DateTimeField(verbose_name=_("Date expired"))
class Meta:
verbose_name = _("Temporary token")
@property
def user(self):
from users.models import User
return User.objects.filter(username=self.username).first()
@property
def is_valid(self):
not_expired = self.date_expired and self.date_expired > timezone.now()
return not self.verified and not_expired
class SuperConnectionToken(ConnectionToken):
class Meta:
proxy = True
verbose_name = _("Super connection token")

View File

@ -0,0 +1,5 @@
from .access_key import *
from .connection_token import *
from .private_token import *
from .sso_token import *
from .temp_token import *

View File

@ -0,0 +1,31 @@
import uuid
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from django.db import models
class AccessKey(models.Model):
id = models.UUIDField(verbose_name='AccessKeyID', primary_key=True,
default=uuid.uuid4, editable=False)
secret = models.UUIDField(verbose_name='AccessKeySecret',
default=uuid.uuid4, editable=False)
user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name='User',
on_delete=models.CASCADE, related_name='access_keys')
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
date_created = models.DateTimeField(auto_now_add=True)
def get_id(self):
return str(self.id)
def get_secret(self):
return str(self.secret)
def get_full_value(self):
return '{}:{}'.format(self.id, self.secret)
def __str__(self):
return str(self.id)
class Meta:
verbose_name = _("Access key")

View File

@ -0,0 +1,140 @@
import time
from datetime import timedelta
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from orgs.mixins.models import OrgModelMixin
from django.db import models
from common.utils import lazyproperty
from common.utils.timezone import as_current_tz
from common.db.models import JMSBaseModel
def date_expired_default():
return timezone.now() + timedelta(seconds=settings.CONNECTION_TOKEN_EXPIRATION)
class ConnectionToken(OrgModelMixin, JMSBaseModel):
user = models.ForeignKey(
'users.User', on_delete=models.SET_NULL, null=True, blank=True,
related_name='connection_tokens', verbose_name=_('User')
)
asset = models.ForeignKey(
'assets.Asset', on_delete=models.SET_NULL, null=True, blank=True,
related_name='connection_tokens', verbose_name=_('Asset'),
)
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
protocol = ''
account = models.CharField(max_length=128, default='', verbose_name=_("Account"))
secret = models.CharField(max_length=64, default='', verbose_name=_("Secret"))
date_expired = models.DateTimeField(default=date_expired_default, verbose_name=_("Date expired"))
class Meta:
ordering = ('-date_expired',)
verbose_name = _('Connection token')
permissions = [
('view_connectiontokensecret', _('Can view connection token secret'))
]
@property
def is_expired(self):
return self.date_expired < timezone.now()
@property
def expire_time(self):
interval = self.date_expired - timezone.now()
seconds = interval.total_seconds()
if seconds < 0:
seconds = 0
return int(seconds)
@property
def is_valid(self):
return not self.is_expired
@classmethod
def get_default_date_expired(cls):
return date_expired_default()
def expire(self):
self.date_expired = timezone.now()
self.save()
def renewal(self):
""" 续期 Token将来支持用户自定义创建 token 后,续期策略要修改 """
self.date_expired = self.get_default_date_expired()
self.save()
# actions 和 expired_at 在 check_valid() 中赋值
actions = expire_at = None
def check_valid(self):
from perms.utils.account import PermAccountUtil
if self.is_expired:
is_valid = False
error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired))
return is_valid, error
if not self.user:
is_valid = False
error = _('User not exists')
return is_valid, error
if not self.user.is_valid:
is_valid = False
error = _('User invalid, disabled or expired')
return is_valid, error
if not self.asset:
is_valid = False
error = _('Asset not exists')
return is_valid, error
if not self.asset.is_active:
is_valid = False
error = _('Asset inactive')
return is_valid, error
if not self.account:
is_valid = False
error = _('Account not exists')
return is_valid, error
actions, expire_at = PermAccountUtil().validate_permission(
self.user, self.asset, self.account
)
if not actions or expire_at < time.time():
is_valid = False
error = _('User has no permission to access asset or permission expired')
return is_valid, error
self.actions = actions
self.expire_at = expire_at
return True, ''
@lazyproperty
def domain(self):
domain = self.asset.domain if self.asset else None
return domain
@lazyproperty
def gateway(self):
from assets.models import Domain
if not self.domain:
return
self.domain: Domain
return self.domain.random_gateway()
@lazyproperty
def cmd_filter_rules(self):
from assets.models import CommandFilterRule
kwargs = {
'user_id': self.user.id,
'account': self.account,
}
if self.asset:
kwargs['asset_id'] = self.asset.id
rules = CommandFilterRule.get_queryset(**kwargs)
return rules
class SuperConnectionToken(ConnectionToken):
class Meta:
proxy = True
verbose_name = _("Super connection token")

View File

@ -0,0 +1,9 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework.authtoken.models import Token
class PrivateToken(Token):
"""Inherit from auth token, otherwise migration is boring"""
class Meta:
verbose_name = _('Private Token')

View File

@ -0,0 +1,18 @@
import uuid
from django.utils.translation import ugettext_lazy as _
from django.db import models
from common.db.models import BaseCreateUpdateModel
class SSOToken(BaseCreateUpdateModel):
"""
类似腾讯企业邮的 [单点登录](https://exmail.qq.com/qy_mng_logic/doc#10036)
出于安全考虑这里的 `token` 使用一次随即过期但我们保留每一个生成过的 `token`
"""
authkey = models.UUIDField(primary_key=True, default=uuid.uuid4, verbose_name=_('Token'))
expired = models.BooleanField(default=False, verbose_name=_('Expired'))
user = models.ForeignKey('users.User', on_delete=models.CASCADE, verbose_name=_('User'), db_constraint=False)
class Meta:
verbose_name = _('SSO token')

View File

@ -0,0 +1,26 @@
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from django.db import models
from common.db.models import JMSBaseModel
class TempToken(JMSBaseModel):
username = models.CharField(max_length=128, verbose_name=_("Username"))
secret = models.CharField(max_length=64, verbose_name=_("Secret"))
verified = models.BooleanField(default=False, verbose_name=_("Verified"))
date_verified = models.DateTimeField(null=True, verbose_name=_("Date verified"))
date_expired = models.DateTimeField(verbose_name=_("Date expired"))
class Meta:
verbose_name = _("Temporary token")
@property
def user(self):
from users.models import User
return User.objects.filter(username=self.username).first()
@property
def is_valid(self):
not_expired = self.date_expired and self.date_expired > timezone.now()
return not self.verified and not_expired

View File

@ -165,6 +165,6 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin):
class Meta: class Meta:
model = ConnectionToken model = ConnectionToken
fields = [ fields = [
'id', 'secret', 'type', 'user', 'asset', 'application', 'system_user', 'id', 'secret', 'type', 'user', 'asset', 'account',
'remote_app', 'cmd_filter_rules', 'domain', 'gateway', 'actions', 'expired_at', 'cmd_filter_rules', 'domain', 'gateway', 'actions', 'expired_at',
] ]

View File

@ -140,7 +140,10 @@ class Organization(OrgRoleMixin, models.Model):
@classmethod @classmethod
def default(cls): def default(cls):
defaults = dict(id=cls.DEFAULT_ID, name=cls.DEFAULT_NAME) defaults = dict(id=cls.DEFAULT_ID, name=cls.DEFAULT_NAME)
obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.DEFAULT_ID, builtin=True) obj, created = cls.objects.get_or_create(defaults=defaults, id=cls.DEFAULT_ID)
if not obj.builtin:
obj.builtin = True
obj.save()
return obj return obj
@classmethod @classmethod

View File

@ -1,34 +1,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
import time
from collections import defaultdict
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from rest_framework.views import APIView, Response
from rest_framework import status
from rest_framework.generics import ( from rest_framework.generics import (
ListAPIView, get_object_or_404, RetrieveAPIView ListAPIView, get_object_or_404
)
from orgs.utils import tmp_to_root_org
from perms.utils.permission import (
get_asset_system_user_ids_with_actions_by_user, validate_permission
) )
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from perms.hands import User, Asset, Account from perms.hands import User, Asset, Account
from perms import serializers from perms import serializers
from perms.models import AssetPermission, Action from perms.models import Action
from perms.utils import PermAccountUtil from perms.utils import PermAccountUtil
logger = get_logger(__name__) logger = get_logger(__name__)
__all__ = [ __all__ = [
'ValidateUserAssetPermissionApi',
'GetUserAssetPermissionActionsApi',
'UserGrantedAssetAccountsApi', 'UserGrantedAssetAccountsApi',
'MyGrantedAssetAccountsApi', 'MyGrantedAssetAccountsApi',
'UserGrantedAssetSpecialAccountsApi', 'UserGrantedAssetSpecialAccountsApi',
@ -36,70 +22,6 @@ __all__ = [
] ]
@method_decorator(tmp_to_root_org(), name='get')
class GetUserAssetPermissionActionsApi(RetrieveAPIView):
serializer_class = serializers.ActionsSerializer
rbac_perms = {
'retrieve': 'perms.view_userassets',
'GET': 'perms.view_userassets',
}
def get_user(self):
user_id = self.request.query_params.get('user_id', '')
user = get_object_or_404(User, id=user_id)
return user
def get_object(self):
asset_id = self.request.query_params.get('asset_id', '')
account = self.request.query_params.get('account', '')
try:
asset_id = uuid.UUID(asset_id)
except ValueError:
return Response({'msg': False}, status=403)
asset = get_object_or_404(Asset, id=asset_id)
system_users_actions = get_asset_system_user_ids_with_actions_by_user(self.get_user(), asset)
# actions = system_users_actions.get(system_user.id)
actions = system_users_actions.get(account)
return {"actions": actions}
@method_decorator(tmp_to_root_org(), name='get')
class ValidateUserAssetPermissionApi(APIView):
rbac_perms = {
'GET': 'perms.view_userassets'
}
def get(self, request, *args, **kwargs):
user_id = self.request.query_params.get('user_id', '')
asset_id = request.query_params.get('asset_id', '')
account = request.query_params.get('account', '')
action_name = request.query_params.get('action_name', '')
data = {
'has_permission': False,
'expire_at': int(time.time()),
'actions': []
}
if not all((user_id, asset_id, account, action_name)):
return Response(data)
user = User.objects.get(id=user_id)
asset = Asset.objects.valid().get(id=asset_id)
has_perm, actions, expire_at = validate_permission(user, asset, account, action_name)
status_code = status.HTTP_200_OK if has_perm else status.HTTP_403_FORBIDDEN
data = {
'has_permission': has_perm,
'actions': actions,
'expire_at': int(expire_at)
}
return Response(data, status=status_code)
class UserGrantedAssetAccountsApi(ListAPIView): class UserGrantedAssetAccountsApi(ListAPIView):
serializer_class = serializers.AccountsGrantedSerializer serializer_class = serializers.AccountsGrantedSerializer
rbac_perms = { rbac_perms = {

View File

@ -84,12 +84,6 @@ permission_urlpatterns = [
# 授权规则中授权的资产 # 授权规则中授权的资产
path('<uuid:pk>/assets/all/', api.AssetPermissionAllAssetListApi.as_view(), name='asset-permission-all-assets'), path('<uuid:pk>/assets/all/', api.AssetPermissionAllAssetListApi.as_view(), name='asset-permission-all-assets'),
path('<uuid:pk>/users/all/', api.AssetPermissionAllUserListApi.as_view(), name='asset-permission-all-users'), path('<uuid:pk>/users/all/', api.AssetPermissionAllUserListApi.as_view(), name='asset-permission-all-users'),
# 验证用户是否有某个资产和系统用户的权限
# Todo: v3 先不动, 可能会修改连接资产时的逻辑, 直接获取认证信息,获取不到就时没有权限,就不需要校验了
path('user/validate/', api.ValidateUserAssetPermissionApi.as_view(), name='validate-user-asset-permission'),
path('user/actions/', api.GetUserAssetPermissionActionsApi.as_view(), name='get-user-asset-permission-actions'),
] ]
asset_permission_urlpatterns = [ asset_permission_urlpatterns = [

View File

@ -1,3 +1,4 @@
import time
from collections import defaultdict from collections import defaultdict
from assets.models import Account from assets.models import Account
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
@ -8,25 +9,29 @@ __all__ = ['PermAccountUtil']
class PermAccountUtil(AssetPermissionUtil): class PermAccountUtil(AssetPermissionUtil):
""" 资产授权账号相关的工具 """ """ 资产授权账号相关的工具 """
def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts
def get_perm_accounts_for_user(self, user, with_actions=False): def get_perm_accounts_for_user(self, user, with_actions=False):
""" 获取授权给用户的所有账号 """ """ 获取授权给用户的所有账号 """
perms = self.get_permissions_for_user(user) perms = self.get_permissions_for_user(user)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts return accounts
def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False, with_perms=False):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
if with_perms:
return perms, accounts
return accounts
def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False): def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False):
""" 获取授权给用户组某个资产的账号 """
perms = self.get_permissions_for_user_group_asset(user_group, asset) perms = self.get_permissions_for_user_group_asset(user_group, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts return accounts
@staticmethod @staticmethod
def get_perm_accounts_for_permissions(permissions, with_actions=False): def get_perm_accounts_for_permissions(permissions, with_actions=False):
""" 获取授权规则包含的账号 """
aid_actions_map = defaultdict(int) aid_actions_map = defaultdict(int)
for perm in permissions: for perm in permissions:
account_ids = perm.get_all_accounts(flat=True) account_ids = perm.get_all_accounts(flat=True)
@ -40,3 +45,14 @@ class PermAccountUtil(AssetPermissionUtil):
account.actions = aid_actions_map.get(str(account.id)) account.actions = aid_actions_map.get(str(account.id))
return accounts return accounts
def validate_permission(self, user, asset, account_username):
""" 校验用户有某个资产下某个账号名的权限 """
perms, accounts = self.get_perm_accounts_for_user_asset(
user, asset, with_actions=True, with_perms=True
)
perm = perms.first()
# Todo: 后面可能需要加上 protocol 进行过滤, 因为同名的账号协议是不一样可能会存在多个
account = accounts.filter(username=account_username).first()
actions = account.actions if account else []
expire_at = perm.date_expired if perm else time.time()
return actions, expire_at

View File

@ -14,21 +14,6 @@ logger = get_logger(__file__)
class AssetPermissionUtil(object): class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """ """ 资产授权相关的方法工具 """
def get_permissions_for_user_asset(self, user, asset):
""" 获取同时包含用户、资产的授权规则 """
user_perm_ids = self.get_permissions_for_user(user, flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = AssetPermission.objects.filter(id__in=perm_ids)
return perms
def get_permissions_for_user_group_asset(self, user_group, asset):
user_perm_ids = self.get_permissions_for_user_groups([user_group], flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = AssetPermission.objects.filter(id__in=perm_ids)
return perms
def get_permissions_for_user(self, user, with_group=True, flat=False): def get_permissions_for_user(self, user, with_group=True, flat=False):
""" 获取用户的授权规则 """ """ 获取用户的授权规则 """
perm_ids = set() perm_ids = set()
@ -43,22 +28,21 @@ class AssetPermissionUtil(object):
perm_ids.update(group_perm_ids) perm_ids.update(group_perm_ids)
if flat: if flat:
return perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
@staticmethod def get_permissions_for_user_groups(self, user_groups, flat=False):
def get_permissions_for_user_groups(user_groups, flat=False):
""" 获取用户组的授权规则 """ """ 获取用户组的授权规则 """
if isinstance(user_groups, list): if isinstance(user_groups, list):
group_ids = [g.id for g in user_groups] group_ids = [g.id for g in user_groups]
else: else:
group_ids = user_groups.values_list('id', flat=True).distinct() group_ids = user_groups.values_list('id', flat=True).distinct()
group_perm_ids = AssetPermission.user_groups.through.objects\ group_perm_ids = AssetPermission.user_groups.through.objects \
.filter(usergroup_id__in=group_ids)\ .filter(usergroup_id__in=group_ids) \
.values_list('assetpermission_id', flat=True).distinct() .values_list('assetpermission_id', flat=True).distinct()
if flat: if flat:
return group_perm_ids return group_perm_ids
perms = AssetPermission.objects.filter(id__in=group_perm_ids) perms = self.get_permissions(ids=group_perm_ids)
return perms return perms
def get_permissions_for_asset(self, asset, with_node=True, flat=False): def get_permissions_for_asset(self, asset, with_node=True, flat=False):
@ -73,11 +57,10 @@ class AssetPermissionUtil(object):
perm_ids.update(node_perm_ids) perm_ids.update(node_perm_ids)
if flat: if flat:
return perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
@staticmethod def get_permissions_for_nodes(self, nodes, with_ancestor=False, flat=False):
def get_permissions_for_nodes(nodes, with_ancestor=False, flat=False):
""" 获取节点的授权规则 """ """ 获取节点的授权规则 """
if with_ancestor: if with_ancestor:
node_ids = set() node_ids = set()
@ -87,93 +70,29 @@ class AssetPermissionUtil(object):
node_ids.update(_node_ids) node_ids.update(_node_ids)
else: else:
node_ids = nodes.values_list('id', flat=True).distinct() node_ids = nodes.values_list('id', flat=True).distinct()
node_perm_ids = AssetPermission.nodes.through.objects.filter(node_id__in=node_ids) \ perm_ids = AssetPermission.nodes.through.objects.filter(node_id__in=node_ids) \
.values_list('assetpermission_id', flat=True).distinct() .values_list('assetpermission_id', flat=True).distinct()
if flat: if flat:
return node_perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=node_perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
def get_permissions_for_user_asset(self, user, asset):
""" 获取同时包含用户、资产的授权规则 """
user_perm_ids = self.get_permissions_for_user(user, flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = self.get_permissions(ids=perm_ids)
return perms
# TODO: 下面的方法放到类中进行实现 def get_permissions_for_user_group_asset(self, user_group, asset):
user_perm_ids = self.get_permissions_for_user_groups([user_group], flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
def validate_permission(user, asset, account, action='connect'): perm_ids = set(user_perm_ids) & set(asset_perm_ids)
asset_perm_ids = get_user_all_asset_perm_ids(user) perms = self.get_permissions(ids=perm_ids)
return perms
asset_perm_ids_from_asset = AssetPermission.assets.through.objects.filter(
assetpermission_id__in=asset_perm_ids,
asset_id=asset.id
).values_list('assetpermission_id', flat=True)
nodes = asset.get_nodes()
node_keys = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
node_keys.update(ancestor_keys)
node_ids = set(Node.objects.filter(key__in=node_keys).values_list('id', flat=True))
asset_perm_ids_from_node = AssetPermission.nodes.through.objects.filter(
assetpermission_id__in=asset_perm_ids,
node_id__in=node_ids
).values_list('assetpermission_id', flat=True)
asset_perm_ids = {*asset_perm_ids_from_asset, *asset_perm_ids_from_node}
asset_perms = AssetPermission.objects\
.filter(id__in=asset_perm_ids, accounts__contains=account)\
.order_by('-date_expired')
if asset_perms:
actions = set()
actions_values = asset_perms.values_list('actions', flat=True)
for value in actions_values:
_actions = Action.value_to_choices(value)
actions.update(_actions)
asset_perm: AssetPermission = asset_perms.first()
actions = list(actions)
expire_at = asset_perm.date_expired.timestamp()
else:
actions = []
expire_at = time.time()
# TODO: 组件改造API完成后统一通过actions判断has_perm
has_perm = action in actions
return has_perm, actions, expire_at
def get_asset_system_user_ids_with_actions(asset_perm_ids, asset: Asset):
nodes = asset.get_nodes()
node_keys = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
node_keys.update(ancestor_keys)
queryset = AssetPermission.objects.filter(id__in=asset_perm_ids)\
.filter(Q(assets=asset) | Q(nodes__key__in=node_keys))
asset_protocols = asset.protocols_as_dict.keys()
values = queryset.filter(
system_users__protocol__in=asset_protocols
).distinct().values_list('system_users', 'actions')
system_users_actions = defaultdict(int)
for system_user_id, actions in values:
if None in (system_user_id, actions):
continue
system_users_actions[system_user_id] |= actions
return system_users_actions
def get_asset_system_user_ids_with_actions_by_user(user: User, asset: Asset):
asset_perm_ids = get_user_all_asset_perm_ids(user)
return get_asset_system_user_ids_with_actions(asset_perm_ids, asset)
def has_asset_system_permission(user: User, asset: Asset, account: str):
systemuser_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuser_actions_mapper.get(account, 0)
if actions:
return True
return False
@staticmethod
def get_permissions(ids):
perms = AssetPermission.objects.filter(id__in=ids).order_by('-date_expired')
return perms