refactor: 修改 ConnectionToken 关联的逻辑(1)

pull/8997/head
Jiangjie.Bai 2022-10-27 15:47:05 +08:00
parent bd001bb262
commit a260da6cec
8 changed files with 122 additions and 309 deletions

View File

@ -180,9 +180,10 @@ class CommandFilterRule(OrgModelMixin):
@classmethod @classmethod
def get_queryset( def get_queryset(
cls, user_id=None, user_group_id=None, system_user_id=None, cls, user_id=None, user_group_id=None, account=None,
asset_id=None, org_id=None asset_id=None, org_id=None
): ):
from perms.models.const import SpecialAccount
user_groups = [] user_groups = []
user = get_object_or_none(User, pk=user_id) user = get_object_or_none(User, pk=user_id)
if user: if user:
@ -191,7 +192,7 @@ class CommandFilterRule(OrgModelMixin):
if user_group: if user_group:
org_id = user_group.org_id org_id = user_group.org_id
user_groups.append(user_group) user_groups.append(user_group)
account = get_object_or_none(Account, pk=system_user_id)
asset = get_object_or_none(Asset, pk=asset_id) asset = get_object_or_none(Asset, pk=asset_id)
q = Q() q = Q()
if user: if user:
@ -200,7 +201,8 @@ class CommandFilterRule(OrgModelMixin):
q |= Q(user_groups__in=set(user_groups)) q |= Q(user_groups__in=set(user_groups))
if account: if account:
org_id = account.org_id org_id = account.org_id
q |= Q(accounts=account) q |= Q(accounts__contains=list(account)) |\
Q(accounts__contains=SpecialAccount.ALL.value)
if asset: if asset:
org_id = asset.org_id org_id = asset.org_id
q |= Q(assets=asset) q |= Q(assets=asset)

View File

@ -1,6 +1,7 @@
import abc import abc
import os import os
import json import json
import time
import base64 import base64
import urllib.parse import urllib.parse
from django.http import HttpResponse from django.http import HttpResponse
@ -41,17 +42,25 @@ class ConnectionTokenMixin:
def get_request_resources(self, serializer): def get_request_resources(self, serializer):
user = self.get_request_resource_user(serializer) user = self.get_request_resource_user(serializer)
asset = serializer.validated_data.get('asset') asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application') account = serializer.validated_data.get('account')
system_user = serializer.validated_data.get('system_user') return user, asset, account
return user, asset, application, system_user
@staticmethod @staticmethod
def check_user_has_resource_permission(user, asset, application, system_user): def check_user_has_resource_permission(user, asset, account):
from perms.utils.asset import has_asset_system_permission from perms.utils.account import PermAccountUtil
if not asset or not user:
error = ''
raise PermissionDenied(error)
if asset and not has_asset_system_permission(user, asset, system_user): actions, expire_at = PermAccountUtil().validate_permission(
error = f'User not has this asset and system user permission: ' \ user, asset, account_username=account
f'user={user.id} system_user={system_user.id} asset={asset.id}' )
if not actions:
error = ''
raise PermissionDenied(error)
if expire_at < time.time():
error = ''
raise PermissionDenied(error) raise PermissionDenied(error)
def get_smart_endpoint(self, protocol, asset=None, application=None): def get_smart_endpoint(self, protocol, asset=None, application=None):
@ -69,13 +78,12 @@ class ConnectionTokenMixin:
return true_value if is_true(os.getenv(env_key, env_default)) else false_value return true_value if is_true(os.getenv(env_key, env_default)) else false_value
def get_client_protocol_data(self, token: ConnectionToken): def get_client_protocol_data(self, token: ConnectionToken):
from assets.models import SystemUser protocol = token.protocol
protocol = token.system_user.protocol
username = token.user.username username = token.user.username
rdp_config = ssh_token = '' rdp_config = ssh_token = ''
if protocol == SystemUser.Protocol.rdp: if protocol == 'rdp':
filename, rdp_config = self.get_rdp_file_info(token) filename, rdp_config = self.get_rdp_file_info(token)
elif protocol == SystemUser.Protocol.ssh: elif protocol == 'ssh':
filename, ssh_token = self.get_ssh_token(token) filename, ssh_token = self.get_ssh_token(token)
else: else:
raise ValueError('Protocol not support: {}'.format(protocol)) raise ValueError('Protocol not support: {}'.format(protocol))
@ -134,15 +142,12 @@ class ConnectionTokenMixin:
rdp_options['screen mode id:i'] = '2' if full_screen else '1' rdp_options['screen mode id:i'] = '2' if full_screen else '1'
# 设置 RDP Server 地址 # 设置 RDP Server 地址
endpoint = self.get_smart_endpoint( endpoint = self.get_smart_endpoint(protocol='rdp', asset=token.asset)
protocol='rdp', asset=token.asset, application=token.application
)
rdp_options['full address:s'] = f'{endpoint.host}:{endpoint.rdp_port}' rdp_options['full address:s'] = f'{endpoint.host}:{endpoint.rdp_port}'
# 设置用户名 # 设置用户名
rdp_options['username:s'] = '{}|{}'.format(token.user.username, str(token.id)) rdp_options['username:s'] = '{}|{}'.format(token.user.username, str(token.id))
if token.system_user.ad_domain: # rdp_options['domain:s'] = token.account_ad_domain
rdp_options['domain:s'] = token.system_user.ad_domain
# 设置宽高 # 设置宽高
height = self.request.query_params.get('height') height = self.request.query_params.get('height')
@ -158,13 +163,12 @@ class ConnectionTokenMixin:
if token.asset: if token.asset:
name = token.asset.name name = token.asset.name
elif token.application and token.application.category_remote_app: # remote-app
app = '||jmservisor' # app = '||jmservisor'
name = token.application.name # rdp_options['remoteapplicationmode:i'] = '1'
rdp_options['remoteapplicationmode:i'] = '1' # rdp_options['alternate shell:s'] = app
rdp_options['alternate shell:s'] = app # rdp_options['remoteapplicationprogram:s'] = app
rdp_options['remoteapplicationprogram:s'] = app # rdp_options['remoteapplicationname:s'] = name
rdp_options['remoteapplicationname:s'] = name
else: else:
name = '*' name = '*'
prefix_name = f'{token.user.username}-{name}' prefix_name = f'{token.user.username}-{name}'
@ -188,16 +192,12 @@ class ConnectionTokenMixin:
def get_ssh_token(self, token: ConnectionToken): def get_ssh_token(self, token: ConnectionToken):
if token.asset: if token.asset:
name = token.asset.name name = token.asset.name
elif token.application:
name = token.application.name
else: else:
name = '*' name = '*'
prefix_name = f'{token.user.username}-{name}' prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name) filename = self.get_connect_filename(prefix_name)
endpoint = self.get_smart_endpoint( endpoint = self.get_smart_endpoint(protocol='ssh', asset=token.asset)
protocol='ssh', asset=token.asset, application=token.application
)
data = { data = {
'ip': endpoint.host, 'ip': endpoint.host,
'port': str(endpoint.ssh_port), 'port': str(endpoint.ssh_port),
@ -251,8 +251,8 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
return token return token
def perform_create(self, serializer): def perform_create(self, serializer):
user, asset, application, system_user = self.get_request_resources(serializer) user, asset, account = self.get_request_resources(serializer)
self.check_user_has_resource_permission(user, asset, application, system_user) self.check_user_has_resource_permission(user, asset, account)
return super(ConnectionTokenViewSet, self).perform_create(serializer) return super(ConnectionTokenViewSet, self).perform_create(serializer)
@action(methods=['POST'], detail=False, url_path='secret-info/detail') @action(methods=['POST'], detail=False, url_path='secret-info/detail')
@ -264,7 +264,6 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
token_id = request.data.get('token') or '' token_id = request.data.get('token') or ''
token = get_object_or_404(ConnectionToken, pk=token_id) token = get_object_or_404(ConnectionToken, pk=token_id)
self.check_token_valid(token) self.check_token_valid(token)
token.load_system_user_auth()
serializer = self.get_serializer(instance=token) serializer = self.get_serializer(instance=token)
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)

View File

@ -1,3 +1,4 @@
import time
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from django.utils import timezone from django.utils import timezone
@ -63,22 +64,20 @@ def date_expired_default():
class ConnectionToken(OrgModelMixin, JMSBaseModel): class ConnectionToken(OrgModelMixin, JMSBaseModel):
secret = models.CharField(max_length=64, default='', verbose_name=_("Secret"))
date_expired = models.DateTimeField(
default=date_expired_default, verbose_name=_("Date expired")
)
user = models.ForeignKey( user = models.ForeignKey(
'users.User', on_delete=models.SET_NULL, verbose_name=_('User'), 'users.User', on_delete=models.SET_NULL, null=True, blank=True,
related_name='connection_tokens', null=True, blank=True related_name='connection_tokens', verbose_name=_('User')
)
asset = models.ForeignKey(
'assets.Asset', on_delete=models.SET_NULL, null=True, blank=True,
related_name='connection_tokens', verbose_name=_('Asset'),
) )
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display")) user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
asset = models.ForeignKey(
'assets.Asset', on_delete=models.SET_NULL, verbose_name=_('Asset'),
related_name='connection_tokens', null=True, blank=True
)
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display")) asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
protocol = ''
account = models.CharField(max_length=128, default='', verbose_name=_("Account")) account = models.CharField(max_length=128, default='', verbose_name=_("Account"))
secret = models.CharField(max_length=64, default='', verbose_name=_("Secret"))
date_expired = models.DateTimeField(default=date_expired_default, verbose_name=_("Date expired"))
class Meta: class Meta:
ordering = ('-date_expired',) ordering = ('-date_expired',)
@ -87,10 +86,6 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
('view_connectiontokensecret', _('Can view connection token secret')) ('view_connectiontokensecret', _('Can view connection token secret'))
] ]
@classmethod
def get_default_date_expired(cls):
return date_expired_default()
@property @property
def is_expired(self): def is_expired(self):
return self.date_expired < timezone.now() return self.date_expired < timezone.now()
@ -103,32 +98,32 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
seconds = 0 seconds = 0
return int(seconds) return int(seconds)
def expire(self):
self.date_expired = timezone.now()
self.save()
@property @property
def is_valid(self): def is_valid(self):
return not self.is_expired return not self.is_expired
def is_type(self, tp): @classmethod
return self.type == tp def get_default_date_expired(cls):
return date_expired_default()
def expire(self):
self.date_expired = timezone.now()
self.save()
def renewal(self): def renewal(self):
""" 续期 Token将来支持用户自定义创建 token 后,续期策略要修改 """ """ 续期 Token将来支持用户自定义创建 token 后,续期策略要修改 """
self.date_expired = self.get_default_date_expired() self.date_expired = self.get_default_date_expired()
self.save() self.save()
actions = expired_at = None # actions 和 expired_at 在 check_valid() 中赋值 # actions 和 expired_at 在 check_valid() 中赋值
actions = expire_at = None
def check_valid(self): def check_valid(self):
from perms.utils.permission import validate_permission as asset_validate_permission from perms.utils.account import PermAccountUtil
if self.is_expired: if self.is_expired:
is_valid = False is_valid = False
error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired)) error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired))
return is_valid, error return is_valid, error
if not self.user: if not self.user:
is_valid = False is_valid = False
error = _('User not exists') error = _('User not exists')
@ -137,44 +132,33 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
is_valid = False is_valid = False
error = _('User invalid, disabled or expired') error = _('User invalid, disabled or expired')
return is_valid, error return is_valid, error
if not self.asset:
is_valid = False
error = _('Asset not exists')
return is_valid, error
if not self.asset.is_active:
is_valid = False
error = _('Asset inactive')
return is_valid, error
if not self.account: if not self.account:
is_valid = False is_valid = False
error = _('Account not exists') error = _('Account not exists')
return is_valid, error return is_valid, error
if not self.asset: actions, expire_at = PermAccountUtil().validate_permission(
is_valid = False
error = _('Asset not exists')
return is_valid, error
if not self.asset.is_active:
is_valid = False
error = _('Asset inactive')
return is_valid, error
has_perm, actions, expired_at = asset_validate_permission(
self.user, self.asset, self.account self.user, self.asset, self.account
) )
if not has_perm: if not actions or expire_at < time.time():
is_valid = False is_valid = False
error = _('User has no permission to access asset or permission expired') error = _('User has no permission to access asset or permission expired')
return is_valid, error return is_valid, error
self.actions = actions self.actions = actions
self.expired_at = expired_at self.expire_at = expire_at
return True, '' return True, ''
@lazyproperty @lazyproperty
def domain(self): def domain(self):
if self.asset: domain = self.asset.domain if self.asset else None
return self.asset.domain
if not self.application:
return
if self.application.category_remote_app:
asset = self.application.get_remote_app_asset()
domain = asset.domain if asset else None
else:
domain = self.application.domain
return domain return domain
@lazyproperty @lazyproperty
@ -185,41 +169,18 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
self.domain: Domain self.domain: Domain
return self.domain.random_gateway() return self.domain.random_gateway()
@lazyproperty
def remote_app(self):
if not self.application:
return {}
if not self.application.category_remote_app:
return {}
return self.application.get_rdp_remote_app_setting()
@lazyproperty
def asset_or_remote_app_asset(self):
if self.asset:
return self.asset
if self.application and self.application.category_remote_app:
return self.application.get_remote_app_asset()
@lazyproperty @lazyproperty
def cmd_filter_rules(self): def cmd_filter_rules(self):
from assets.models import CommandFilterRule from assets.models import CommandFilterRule
kwargs = { kwargs = {
'user_id': self.user.id, 'user_id': self.user.id,
'system_user_id': self.system_user.id, 'account': self.account,
} }
if self.asset: if self.asset:
kwargs['asset_id'] = self.asset.id kwargs['asset_id'] = self.asset.id
elif self.application:
kwargs['application_id'] = self.application_id
rules = CommandFilterRule.get_queryset(**kwargs) rules = CommandFilterRule.get_queryset(**kwargs)
return rules return rules
def load_system_user_auth(self):
if self.asset:
self.system_user.load_asset_more_auth(self.asset.id, self.user.username, self.user.id)
elif self.application:
self.system_user.load_app_more_auth(self.application.id, self.user.username, self.user.id)
class TempToken(JMSBaseModel): class TempToken(JMSBaseModel):
username = models.CharField(max_length=128, verbose_name=_("Username")) username = models.CharField(max_length=128, verbose_name=_("Username"))

View File

@ -165,6 +165,6 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin):
class Meta: class Meta:
model = ConnectionToken model = ConnectionToken
fields = [ fields = [
'id', 'secret', 'type', 'user', 'asset', 'application', 'system_user', 'id', 'secret', 'type', 'user', 'asset', 'account',
'remote_app', 'cmd_filter_rules', 'domain', 'gateway', 'actions', 'expired_at', 'cmd_filter_rules', 'domain', 'gateway', 'actions', 'expired_at',
] ]

View File

@ -1,34 +1,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
import time
from collections import defaultdict
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from rest_framework.views import APIView, Response
from rest_framework import status
from rest_framework.generics import ( from rest_framework.generics import (
ListAPIView, get_object_or_404, RetrieveAPIView ListAPIView, get_object_or_404
)
from orgs.utils import tmp_to_root_org
from perms.utils.permission import (
get_asset_system_user_ids_with_actions_by_user, validate_permission
) )
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from perms.hands import User, Asset, Account from perms.hands import User, Asset, Account
from perms import serializers from perms import serializers
from perms.models import AssetPermission, Action from perms.models import Action
from perms.utils import PermAccountUtil from perms.utils import PermAccountUtil
logger = get_logger(__name__) logger = get_logger(__name__)
__all__ = [ __all__ = [
'ValidateUserAssetPermissionApi',
'GetUserAssetPermissionActionsApi',
'UserGrantedAssetAccountsApi', 'UserGrantedAssetAccountsApi',
'MyGrantedAssetAccountsApi', 'MyGrantedAssetAccountsApi',
'UserGrantedAssetSpecialAccountsApi', 'UserGrantedAssetSpecialAccountsApi',
@ -36,70 +22,6 @@ __all__ = [
] ]
@method_decorator(tmp_to_root_org(), name='get')
class GetUserAssetPermissionActionsApi(RetrieveAPIView):
serializer_class = serializers.ActionsSerializer
rbac_perms = {
'retrieve': 'perms.view_userassets',
'GET': 'perms.view_userassets',
}
def get_user(self):
user_id = self.request.query_params.get('user_id', '')
user = get_object_or_404(User, id=user_id)
return user
def get_object(self):
asset_id = self.request.query_params.get('asset_id', '')
account = self.request.query_params.get('account', '')
try:
asset_id = uuid.UUID(asset_id)
except ValueError:
return Response({'msg': False}, status=403)
asset = get_object_or_404(Asset, id=asset_id)
system_users_actions = get_asset_system_user_ids_with_actions_by_user(self.get_user(), asset)
# actions = system_users_actions.get(system_user.id)
actions = system_users_actions.get(account)
return {"actions": actions}
@method_decorator(tmp_to_root_org(), name='get')
class ValidateUserAssetPermissionApi(APIView):
rbac_perms = {
'GET': 'perms.view_userassets'
}
def get(self, request, *args, **kwargs):
user_id = self.request.query_params.get('user_id', '')
asset_id = request.query_params.get('asset_id', '')
account = request.query_params.get('account', '')
action_name = request.query_params.get('action_name', '')
data = {
'has_permission': False,
'expire_at': int(time.time()),
'actions': []
}
if not all((user_id, asset_id, account, action_name)):
return Response(data)
user = User.objects.get(id=user_id)
asset = Asset.objects.valid().get(id=asset_id)
has_perm, actions, expire_at = validate_permission(user, asset, account, action_name)
status_code = status.HTTP_200_OK if has_perm else status.HTTP_403_FORBIDDEN
data = {
'has_permission': has_perm,
'actions': actions,
'expire_at': int(expire_at)
}
return Response(data, status=status_code)
class UserGrantedAssetAccountsApi(ListAPIView): class UserGrantedAssetAccountsApi(ListAPIView):
serializer_class = serializers.AccountsGrantedSerializer serializer_class = serializers.AccountsGrantedSerializer
rbac_perms = { rbac_perms = {

View File

@ -84,12 +84,6 @@ permission_urlpatterns = [
# 授权规则中授权的资产 # 授权规则中授权的资产
path('<uuid:pk>/assets/all/', api.AssetPermissionAllAssetListApi.as_view(), name='asset-permission-all-assets'), path('<uuid:pk>/assets/all/', api.AssetPermissionAllAssetListApi.as_view(), name='asset-permission-all-assets'),
path('<uuid:pk>/users/all/', api.AssetPermissionAllUserListApi.as_view(), name='asset-permission-all-users'), path('<uuid:pk>/users/all/', api.AssetPermissionAllUserListApi.as_view(), name='asset-permission-all-users'),
# 验证用户是否有某个资产和系统用户的权限
# Todo: v3 先不动, 可能会修改连接资产时的逻辑, 直接获取认证信息,获取不到就时没有权限,就不需要校验了
path('user/validate/', api.ValidateUserAssetPermissionApi.as_view(), name='validate-user-asset-permission'),
path('user/actions/', api.GetUserAssetPermissionActionsApi.as_view(), name='get-user-asset-permission-actions'),
] ]
asset_permission_urlpatterns = [ asset_permission_urlpatterns = [

View File

@ -1,3 +1,4 @@
import time
from collections import defaultdict from collections import defaultdict
from assets.models import Account from assets.models import Account
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
@ -8,25 +9,29 @@ __all__ = ['PermAccountUtil']
class PermAccountUtil(AssetPermissionUtil): class PermAccountUtil(AssetPermissionUtil):
""" 资产授权账号相关的工具 """ """ 资产授权账号相关的工具 """
def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts
def get_perm_accounts_for_user(self, user, with_actions=False): def get_perm_accounts_for_user(self, user, with_actions=False):
""" 获取授权给用户的所有账号 """ """ 获取授权给用户的所有账号 """
perms = self.get_permissions_for_user(user) perms = self.get_permissions_for_user(user)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts return accounts
def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False, with_perms=False):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
if with_perms:
return perms, accounts
return accounts
def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False): def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False):
""" 获取授权给用户组某个资产的账号 """
perms = self.get_permissions_for_user_group_asset(user_group, asset) perms = self.get_permissions_for_user_group_asset(user_group, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts return accounts
@staticmethod @staticmethod
def get_perm_accounts_for_permissions(permissions, with_actions=False): def get_perm_accounts_for_permissions(permissions, with_actions=False):
""" 获取授权规则包含的账号 """
aid_actions_map = defaultdict(int) aid_actions_map = defaultdict(int)
for perm in permissions: for perm in permissions:
account_ids = perm.get_all_accounts(flat=True) account_ids = perm.get_all_accounts(flat=True)
@ -40,3 +45,14 @@ class PermAccountUtil(AssetPermissionUtil):
account.actions = aid_actions_map.get(str(account.id)) account.actions = aid_actions_map.get(str(account.id))
return accounts return accounts
def validate_permission(self, user, asset, account_username):
""" 校验用户有某个资产下某个账号名的权限 """
perms, accounts = self.get_perm_accounts_for_user_asset(
user, asset, with_actions=True, with_perms=True
)
perm = perms.first()
# Todo: 后面可能需要加上 protocol 进行过滤, 因为同名的账号协议是不一样可能会存在多个
account = accounts.filter(username=account_username).first()
actions = account.actions if account else []
expire_at = perm.date_expired if perm else time.time()
return actions, expire_at

View File

@ -14,21 +14,6 @@ logger = get_logger(__file__)
class AssetPermissionUtil(object): class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """ """ 资产授权相关的方法工具 """
def get_permissions_for_user_asset(self, user, asset):
""" 获取同时包含用户、资产的授权规则 """
user_perm_ids = self.get_permissions_for_user(user, flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = AssetPermission.objects.filter(id__in=perm_ids)
return perms
def get_permissions_for_user_group_asset(self, user_group, asset):
user_perm_ids = self.get_permissions_for_user_groups([user_group], flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = AssetPermission.objects.filter(id__in=perm_ids)
return perms
def get_permissions_for_user(self, user, with_group=True, flat=False): def get_permissions_for_user(self, user, with_group=True, flat=False):
""" 获取用户的授权规则 """ """ 获取用户的授权规则 """
perm_ids = set() perm_ids = set()
@ -43,22 +28,21 @@ class AssetPermissionUtil(object):
perm_ids.update(group_perm_ids) perm_ids.update(group_perm_ids)
if flat: if flat:
return perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
@staticmethod def get_permissions_for_user_groups(self, user_groups, flat=False):
def get_permissions_for_user_groups(user_groups, flat=False):
""" 获取用户组的授权规则 """ """ 获取用户组的授权规则 """
if isinstance(user_groups, list): if isinstance(user_groups, list):
group_ids = [g.id for g in user_groups] group_ids = [g.id for g in user_groups]
else: else:
group_ids = user_groups.values_list('id', flat=True).distinct() group_ids = user_groups.values_list('id', flat=True).distinct()
group_perm_ids = AssetPermission.user_groups.through.objects\ group_perm_ids = AssetPermission.user_groups.through.objects \
.filter(usergroup_id__in=group_ids)\ .filter(usergroup_id__in=group_ids) \
.values_list('assetpermission_id', flat=True).distinct() .values_list('assetpermission_id', flat=True).distinct()
if flat: if flat:
return group_perm_ids return group_perm_ids
perms = AssetPermission.objects.filter(id__in=group_perm_ids) perms = self.get_permissions(ids=group_perm_ids)
return perms return perms
def get_permissions_for_asset(self, asset, with_node=True, flat=False): def get_permissions_for_asset(self, asset, with_node=True, flat=False):
@ -73,11 +57,10 @@ class AssetPermissionUtil(object):
perm_ids.update(node_perm_ids) perm_ids.update(node_perm_ids)
if flat: if flat:
return perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
@staticmethod def get_permissions_for_nodes(self, nodes, with_ancestor=False, flat=False):
def get_permissions_for_nodes(nodes, with_ancestor=False, flat=False):
""" 获取节点的授权规则 """ """ 获取节点的授权规则 """
if with_ancestor: if with_ancestor:
node_ids = set() node_ids = set()
@ -87,93 +70,29 @@ class AssetPermissionUtil(object):
node_ids.update(_node_ids) node_ids.update(_node_ids)
else: else:
node_ids = nodes.values_list('id', flat=True).distinct() node_ids = nodes.values_list('id', flat=True).distinct()
node_perm_ids = AssetPermission.nodes.through.objects.filter(node_id__in=node_ids) \ perm_ids = AssetPermission.nodes.through.objects.filter(node_id__in=node_ids) \
.values_list('assetpermission_id', flat=True).distinct() .values_list('assetpermission_id', flat=True).distinct()
if flat: if flat:
return node_perm_ids return perm_ids
perms = AssetPermission.objects.filter(id__in=node_perm_ids) perms = self.get_permissions(ids=perm_ids)
return perms return perms
def get_permissions_for_user_asset(self, user, asset):
""" 获取同时包含用户、资产的授权规则 """
user_perm_ids = self.get_permissions_for_user(user, flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
perm_ids = set(user_perm_ids) & set(asset_perm_ids)
perms = self.get_permissions(ids=perm_ids)
return perms
# TODO: 下面的方法放到类中进行实现 def get_permissions_for_user_group_asset(self, user_group, asset):
user_perm_ids = self.get_permissions_for_user_groups([user_group], flat=True)
asset_perm_ids = self.get_permissions_for_asset(asset, flat=True)
def validate_permission(user, asset, account, action='connect'): perm_ids = set(user_perm_ids) & set(asset_perm_ids)
asset_perm_ids = get_user_all_asset_perm_ids(user) perms = self.get_permissions(ids=perm_ids)
return perms
asset_perm_ids_from_asset = AssetPermission.assets.through.objects.filter(
assetpermission_id__in=asset_perm_ids,
asset_id=asset.id
).values_list('assetpermission_id', flat=True)
nodes = asset.get_nodes()
node_keys = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
node_keys.update(ancestor_keys)
node_ids = set(Node.objects.filter(key__in=node_keys).values_list('id', flat=True))
asset_perm_ids_from_node = AssetPermission.nodes.through.objects.filter(
assetpermission_id__in=asset_perm_ids,
node_id__in=node_ids
).values_list('assetpermission_id', flat=True)
asset_perm_ids = {*asset_perm_ids_from_asset, *asset_perm_ids_from_node}
asset_perms = AssetPermission.objects\
.filter(id__in=asset_perm_ids, accounts__contains=account)\
.order_by('-date_expired')
if asset_perms:
actions = set()
actions_values = asset_perms.values_list('actions', flat=True)
for value in actions_values:
_actions = Action.value_to_choices(value)
actions.update(_actions)
asset_perm: AssetPermission = asset_perms.first()
actions = list(actions)
expire_at = asset_perm.date_expired.timestamp()
else:
actions = []
expire_at = time.time()
# TODO: 组件改造API完成后统一通过actions判断has_perm
has_perm = action in actions
return has_perm, actions, expire_at
def get_asset_system_user_ids_with_actions(asset_perm_ids, asset: Asset):
nodes = asset.get_nodes()
node_keys = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
node_keys.update(ancestor_keys)
queryset = AssetPermission.objects.filter(id__in=asset_perm_ids)\
.filter(Q(assets=asset) | Q(nodes__key__in=node_keys))
asset_protocols = asset.protocols_as_dict.keys()
values = queryset.filter(
system_users__protocol__in=asset_protocols
).distinct().values_list('system_users', 'actions')
system_users_actions = defaultdict(int)
for system_user_id, actions in values:
if None in (system_user_id, actions):
continue
system_users_actions[system_user_id] |= actions
return system_users_actions
def get_asset_system_user_ids_with_actions_by_user(user: User, asset: Asset):
asset_perm_ids = get_user_all_asset_perm_ids(user)
return get_asset_system_user_ids_with_actions(asset_perm_ids, asset)
def has_asset_system_permission(user: User, asset: Asset, account: str):
systemuser_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuser_actions_mapper.get(account, 0)
if actions:
return True
return False
@staticmethod
def get_permissions(ids):
perms = AssetPermission.objects.filter(id__in=ids).order_by('-date_expired')
return perms