Merge pull request #11981 from jumpserver/pr@dev@feat_perm_add_protocols

perf: 资产授权添加协议
pull/11992/head
老广 2023-10-30 10:12:45 +08:00 committed by GitHub
commit 7669744312
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 255 additions and 130 deletions

View File

@ -6,4 +6,5 @@ from .label import *
from .mixin import * from .mixin import *
from .node import * from .node import *
from .platform import * from .platform import *
from .protocol import *
from .tree import * from .tree import *

View File

@ -13,7 +13,7 @@ __all__ = ['CategoryViewSet']
class CategoryViewSet(ListModelMixin, JMSGenericViewSet): class CategoryViewSet(ListModelMixin, JMSGenericViewSet):
serializer_classes = { serializer_classes = {
'default': CategorySerializer, 'default': CategorySerializer,
'types': TypeSerializer 'types': TypeSerializer,
} }
permission_classes = (IsValidUser,) permission_classes = (IsValidUser,)

View File

@ -0,0 +1,15 @@
from rest_framework.generics import ListAPIView
from assets import serializers
from assets.const import Protocol
from common.permissions import IsValidUser
__all__ = ['ProtocolListApi']
class ProtocolListApi(ListAPIView):
serializer_class = serializers.ProtocolSerializer
permission_classes = (IsValidUser,)
def get_queryset(self):
return list(Protocol.protocols())

View File

@ -294,6 +294,10 @@ class Protocol(ChoicesMixin, models.TextChoices):
**cls.gpt_protocols(), **cls.gpt_protocols(),
} }
@classmethod
def protocols(cls):
return cls.settings().keys()
@classmethod @classmethod
@cached_method(ttl=600) @cached_method(ttl=600)
def xpack_protocols(cls): def xpack_protocols(cls):

View File

@ -1,5 +1,9 @@
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
__all__ = [
'TypeSerializer', 'CategorySerializer', 'ProtocolSerializer'
]
class TypeSerializer(serializers.Serializer): class TypeSerializer(serializers.Serializer):
@ -13,3 +17,8 @@ class CategorySerializer(serializers.Serializer):
label = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Label')) label = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Label'))
value = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Value')) value = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Value'))
types = TypeSerializer(many=True, required=False, label=_('Types'), read_only=True) types = TypeSerializer(many=True, required=False, label=_('Types'), read_only=True)
class ProtocolSerializer(serializers.Serializer):
label = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Label'))
value = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Value'))

View File

@ -26,6 +26,7 @@ router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-set
urlpatterns = [ urlpatterns = [
# path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'), # path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),
path('protocols/', api.ProtocolListApi.as_view(), name='asset-protocol'),
path('assets/<uuid:pk>/tasks/', api.AssetTaskCreateApi.as_view(), name='asset-task-create'), path('assets/<uuid:pk>/tasks/', api.AssetTaskCreateApi.as_view(), name='asset-task-create'),
path('assets/tasks/', api.AssetsTaskCreateApi.as_view(), name='assets-task-create'), path('assets/tasks/', api.AssetsTaskCreateApi.as_view(), name='assets-task-create'),
path('assets/<uuid:pk>/perm-users/', api.AssetPermUserListApi.as_view(), name='asset-perm-user-list'), path('assets/<uuid:pk>/perm-users/', api.AssetPermUserListApi.as_view(), name='asset-perm-user-list'),

View File

@ -351,8 +351,9 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
self._insert_connect_options(data, user) self._insert_connect_options(data, user)
asset = data.get('asset') asset = data.get('asset')
account_name = data.get('account') account_name = data.get('account')
protocol = data.get('protocol')
self.input_username = self.get_input_username(data) self.input_username = self.get_input_username(data)
_data = self._validate(user, asset, account_name) _data = self._validate(user, asset, account_name, protocol)
data.update(_data) data.update(_data)
return serializer return serializer
@ -360,12 +361,12 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
user = token.user user = token.user
asset = token.asset asset = token.asset
account_name = token.account account_name = token.account
_data = self._validate(user, asset, account_name) _data = self._validate(user, asset, account_name, token.protocol)
for k, v in _data.items(): for k, v in _data.items():
setattr(token, k, v) setattr(token, k, v)
return token return token
def _validate(self, user, asset, account_name): def _validate(self, user, asset, account_name, protocol):
data = dict() data = dict()
data['org_id'] = asset.org_id data['org_id'] = asset.org_id
data['user'] = user data['user'] = user
@ -374,7 +375,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
if account_name == AliasAccount.ANON and asset.category not in ['web', 'custom']: if account_name == AliasAccount.ANON and asset.category not in ['web', 'custom']:
raise ValidationError(_('Anonymous account is not supported for this asset')) raise ValidationError(_('Anonymous account is not supported for this asset'))
account = self._validate_perm(user, asset, account_name) account = self._validate_perm(user, asset, account_name, protocol)
if account.has_secret: if account.has_secret:
data['input_secret'] = '' data['input_secret'] = ''
@ -387,9 +388,9 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
return data return data
@staticmethod @staticmethod
def _validate_perm(user, asset, account_name): def _validate_perm(user, asset, account_name, protocol):
from perms.utils.account import PermAccountUtil from perms.utils.asset_perm import PermAssetDetailUtil
account = PermAccountUtil().validate_permission(user, asset, account_name) account = PermAssetDetailUtil(user, asset).validate_permission(account_name, protocol)
if not account or not account.actions: if not account or not account.actions:
msg = _('Account not found') msg = _('Account not found')
raise JMSException(code='perm_account_invalid', detail=msg) raise JMSException(code='perm_account_invalid', detail=msg)

View File

@ -97,10 +97,9 @@ class ConnectionToken(JMSOrgBaseModel):
@lazyproperty @lazyproperty
def permed_account(self): def permed_account(self):
from perms.utils import PermAccountUtil from perms.utils import PermAssetDetailUtil
permed_account = PermAccountUtil().validate_permission( permed_account = PermAssetDetailUtil(self.user, self.asset) \
self.user, self.asset, self.account .validate_permission(self.account, self.protocol)
)
return permed_account return permed_account
@lazyproperty @lazyproperty
@ -115,6 +114,7 @@ class ConnectionToken(JMSOrgBaseModel):
if not self.is_active: if not self.is_active:
error = _('Connection token inactive') error = _('Connection token inactive')
raise PermissionDenied(error) raise PermissionDenied(error)
if self.is_expired: if self.is_expired:
error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired)) error = _('Connection token expired at: {}').format(as_current_tz(self.date_expired))
raise PermissionDenied(error) raise PermissionDenied(error)

View File

@ -55,4 +55,4 @@ class IsValidUserOrConnectionToken(IsValidUser):
return False return False
with tmp_to_root_org(): with tmp_to_root_org():
token = get_object_or_none(ConnectionToken, id=token_id) token = get_object_or_none(ConnectionToken, id=token_id)
return token and token.is_valid return token and token.is_valid()

View File

@ -4,9 +4,10 @@ from rest_framework.generics import ListAPIView, get_object_or_404
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from perms import serializers from perms import serializers
from perms.hands import Asset from perms.hands import Asset
from perms.utils import PermAccountUtil from perms.utils import PermAssetDetailUtil
from .mixin import SelfOrPKUserMixin from .mixin import SelfOrPKUserMixin
from ...models import AssetPermission from ...models import AssetPermission
logger = get_logger(__name__) logger = get_logger(__name__)
__all__ = [ __all__ = [
@ -26,5 +27,5 @@ class UserPermedAssetAccountsApi(SelfOrPKUserMixin, ListAPIView):
return asset return asset
def get_queryset(self): def get_queryset(self):
accounts = PermAccountUtil().get_permed_accounts_for_user(self.user, self.asset) accounts = PermAssetDetailUtil(self.user, self.asset).get_permed_accounts_for_user()
return accounts return accounts

View File

@ -1,14 +1,15 @@
import abc import abc
from rest_framework.generics import ListAPIView from rest_framework.generics import ListAPIView, RetrieveAPIView
from assets.api.asset.asset import AssetFilterSet from assets.api.asset.asset import AssetFilterSet
from assets.models import Asset, Node from assets.models import Asset, Node
from common.utils import get_logger, lazyproperty, is_uuid from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org
from perms import serializers from perms import serializers
from perms.pagination import AllPermedAssetPagination from perms.pagination import AllPermedAssetPagination
from perms.pagination import NodePermedAssetPagination from perms.pagination import NodePermedAssetPagination
from perms.utils import UserPermAssetUtil from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
from .mixin import ( from .mixin import (
SelfOrPKUserMixin SelfOrPKUserMixin
) )
@ -18,11 +19,25 @@ __all__ = [
'UserDirectPermedAssetsApi', 'UserDirectPermedAssetsApi',
'UserFavoriteAssetsApi', 'UserFavoriteAssetsApi',
'UserPermedNodeAssetsApi', 'UserPermedNodeAssetsApi',
'UserPermedAssetRetrieveApi',
] ]
logger = get_logger(__name__) logger = get_logger(__name__)
class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
serializer_class = serializers.AssetPermedDetailSerializer
def get_object(self):
with tmp_to_root_org():
asset_id = self.kwargs.get('pk')
util = PermAssetDetailUtil(self.user, asset_id)
asset = util.asset
asset.permed_accounts = util.get_permed_accounts_for_user()
asset.permed_protocols = util.get_permed_protocols_for_user()
return asset
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView): class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
ordering = ('name',) ordering = ('name',)
search_fields = ('name', 'address', 'comment') search_fields = ('name', 'address', 'comment')
@ -30,12 +45,6 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
filterset_class = AssetFilterSet filterset_class = AssetFilterSet
serializer_class = serializers.AssetPermedSerializer serializer_class = serializers.AssetPermedSerializer
def get_serializer_class(self):
serializer_class = super().get_serializer_class()
if self.request.query_params.get('id'):
serializer_class = serializers.AssetPermedDetailSerializer
return serializer_class
def get_queryset(self): def get_queryset(self):
if getattr(self, 'swagger_fake_view', False): if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none() return Asset.objects.none()

View File

@ -21,7 +21,7 @@ from common.utils import get_object_or_none, lazyproperty
from common.utils.common import timeit from common.utils.common import timeit
from perms.hands import Node from perms.hands import Node
from perms.models import PermNode from perms.models import PermNode
from perms.utils import PermAccountUtil, UserPermNodeUtil from perms.utils import PermAssetDetailUtil, UserPermNodeUtil
from perms.utils import UserPermAssetUtil from perms.utils import UserPermAssetUtil
from .mixin import RebuildTreeMixin from .mixin import RebuildTreeMixin
from ..mixin import SelfOrPKUserMixin from ..mixin import SelfOrPKUserMixin
@ -225,8 +225,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
return token return token
def get_account_secret(self, token: ConnectionToken): def get_account_secret(self, token: ConnectionToken):
util = PermAccountUtil() util = PermAssetDetailUtil(self.user, token.asset)
accounts = util.get_permed_accounts_for_user(self.user, token.asset) accounts = util.get_permed_accounts_for_user()
account_name = token.account account_name = token.account
if account_name in [AliasAccount.INPUT, AliasAccount.USER]: if account_name in [AliasAccount.INPUT, AliasAccount.USER]:

View File

@ -0,0 +1,20 @@
# Generated by Django 4.1.10 on 2023-10-25 02:45
from django.db import migrations, models
import perms.models.asset_permission
class Migration(migrations.Migration):
dependencies = [
('perms', '0034_auto_20230525_1734'),
]
operations = [
migrations.AddField(
model_name='assetpermission',
name='protocols',
field=models.JSONField(default=perms.models.asset_permission.default_protocols, verbose_name='Protocols'),
),
]

View File

@ -52,6 +52,10 @@ class AssetPermissionManager(OrgManager):
return self.get_queryset().filter(Q(date_start__lte=now) | Q(date_expired__gte=now)) return self.get_queryset().filter(Q(date_start__lte=now) | Q(date_expired__gte=now))
def default_protocols():
return ['all']
class AssetPermission(JMSOrgBaseModel): class AssetPermission(JMSOrgBaseModel):
name = models.CharField(max_length=128, verbose_name=_('Name')) name = models.CharField(max_length=128, verbose_name=_('Name'))
users = models.ManyToManyField( users = models.ManyToManyField(
@ -68,6 +72,7 @@ class AssetPermission(JMSOrgBaseModel):
) )
# 特殊的账号: @ALL, @INPUT @USER 默认包含,将来在全局设置中进行控制. # 特殊的账号: @ALL, @INPUT @USER 默认包含,将来在全局设置中进行控制.
accounts = models.JSONField(default=list, verbose_name=_("Account")) accounts = models.JSONField(default=list, verbose_name=_("Account"))
protocols = models.JSONField(default=default_protocols, verbose_name=_("Protocols"))
actions = models.IntegerField(default=ActionChoices.connect, verbose_name=_("Actions")) actions = models.IntegerField(default=ActionChoices.connect, verbose_name=_("Actions"))
date_start = models.DateTimeField(default=timezone.now, db_index=True, verbose_name=_("Date start")) date_start = models.DateTimeField(default=timezone.now, db_index=True, verbose_name=_("Date start"))
date_expired = models.DateTimeField( date_expired = models.DateTimeField(

View File

@ -37,6 +37,7 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid")) is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
accounts = serializers.ListField(label=_("Account"), required=False) accounts = serializers.ListField(label=_("Account"), required=False)
protocols = serializers.ListField(label=_("Protocols"), required=False)
template_accounts = AccountTemplate.objects.none() template_accounts = AccountTemplate.objects.none()
@ -44,7 +45,7 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
model = AssetPermission model = AssetPermission
fields_mini = ["id", "name"] fields_mini = ["id", "name"]
fields_generic = [ fields_generic = [
"accounts", "actions", "created_by", "date_created", "accounts", "protocols", "actions", "created_by", "date_created",
"date_start", "date_expired", "is_active", "is_expired", "date_start", "date_expired", "is_active", "is_expired",
"is_valid", "comment", "from_ticket", "is_valid", "comment", "from_ticket",
] ]

View File

@ -8,7 +8,7 @@ from rest_framework import serializers
from accounts.models import Account from accounts.models import Account
from assets.const import Category, AllTypes from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetProtocolsPermsSerializer, AssetLabelSerializer from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField from perms.serializers.permission import ActionChoicesField
@ -22,7 +22,6 @@ __all__ = [
class AssetPermedSerializer(OrgResourceModelSerializerMixin): class AssetPermedSerializer(OrgResourceModelSerializerMixin):
""" 被授权资产的数据结构 """ """ 被授权资产的数据结构 """
platform = ObjectRelatedField(required=False, queryset=Platform.objects, label=_('Platform')) platform = ObjectRelatedField(required=False, queryset=Platform.objects, label=_('Platform'))
protocols = AssetProtocolsPermsSerializer(many=True, required=False, label=_('Protocols'))
category = LabeledChoiceField(choices=Category.choices, read_only=True, label=_('Category')) category = LabeledChoiceField(choices=Category.choices, read_only=True, label=_('Category'))
type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type')) type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type'))
labels = AssetLabelSerializer(many=True, required=False, label=_('Label')) labels = AssetLabelSerializer(many=True, required=False, label=_('Label'))
@ -35,30 +34,25 @@ class AssetPermedSerializer(OrgResourceModelSerializerMixin):
'comment', 'org_id', 'is_active', 'date_verified', 'comment', 'org_id', 'is_active', 'date_verified',
'created_by', 'date_created', 'connectivity', 'nodes', 'labels' 'created_by', 'date_created', 'connectivity', 'nodes', 'labels'
] ]
fields = only_fields + ['protocols', 'category', 'type'] + ['org_name'] fields = only_fields + ['category', 'type'] + ['org_name']
read_only_fields = fields read_only_fields = fields
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'nodes', 'labels') \ queryset = queryset.prefetch_related('domain', 'nodes', 'labels') \
.prefetch_related('platform', 'protocols') \ .prefetch_related('platform') \
.annotate(category=F("platform__category")) \ .annotate(category=F("platform__category")) \
.annotate(type=F("platform__type")) .annotate(type=F("platform__type"))
return queryset return queryset
class AssetPermedDetailSerializer(AssetPermedSerializer):
class Meta(AssetPermedSerializer.Meta):
fields = AssetPermedSerializer.Meta.fields + ['spec_info']
read_only_fields = fields
class NodePermedSerializer(serializers.ModelSerializer): class NodePermedSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Node model = Node
fields = [ fields = [
'id', 'name', 'key', 'value', 'org_id', "assets_amount" 'id', 'name', 'key', 'value',
'org_id', "assets_amount"
] ]
read_only_fields = fields read_only_fields = fields
@ -73,3 +67,13 @@ class AccountsPermedSerializer(serializers.ModelSerializer):
'has_secret', 'secret_type', 'actions' 'has_secret', 'secret_type', 'actions'
] ]
read_only_fields = fields read_only_fields = fields
class AssetPermedDetailSerializer(AssetPermedSerializer):
# 前面特意加了 permed避免返回的是资产本身的
permed_protocols = AssetProtocolsPermsSerializer(many=True, required=False, label=_('Protocols'))
permed_accounts = AccountsPermedSerializer(label=_("Accounts"), required=False, many=True)
class Meta(AssetPermedSerializer.Meta):
fields = AssetPermedSerializer.Meta.fields + ['spec_info', 'permed_protocols', 'permed_accounts']
read_only_fields = fields

View File

@ -5,8 +5,11 @@ from .. import api
user_permission_urlpatterns = [ user_permission_urlpatterns = [
# <str:user> such as: my | self | user.id # <str:user> such as: my | self | user.id
# assets # assets
path('<str:user>/assets/<uuid:pk>/', api.UserPermedAssetRetrieveApi.as_view(),
name='user-permed-asset'),
path('<str:user>/assets/', api.UserAllPermedAssetsApi.as_view(), path('<str:user>/assets/', api.UserAllPermedAssetsApi.as_view(),
name='user-all-assets'), name='user-all-assets'),
path('<str:user>/nodes/ungrouped/assets/', api.UserDirectPermedAssetsApi.as_view(), path('<str:user>/nodes/ungrouped/assets/', api.UserDirectPermedAssetsApi.as_view(),
name='user-direct-assets'), name='user-direct-assets'),
path('<str:user>/nodes/favorite/assets/', api.UserFavoriteAssetsApi.as_view(), path('<str:user>/nodes/favorite/assets/', api.UserFavoriteAssetsApi.as_view(),
@ -47,9 +50,6 @@ user_permission_urlpatterns = [
path('<str:user>/nodes/children-with-k8s/tree/', path('<str:user>/nodes/children-with-k8s/tree/',
api.UserGrantedK8sAsTreeApi.as_view(), api.UserGrantedK8sAsTreeApi.as_view(),
name='user-nodes-children-with-k8s-as-tree'), name='user-nodes-children-with-k8s-as-tree'),
# accounts
path('<str:user>/assets/<uuid:asset_id>/accounts/', api.UserPermedAssetAccountsApi.as_view(),
name='user-permed-asset-accounts'),
] ]
user_group_permission_urlpatterns = [ user_group_permission_urlpatterns = [

View File

@ -1,4 +1,4 @@
from .asset_perm import *
from .permission import * from .permission import *
from .account import *
from .user_perm_tree import *
from .user_perm import * from .user_perm import *
from .user_perm_tree import *

View File

@ -1,84 +0,0 @@
from collections import defaultdict
from accounts.const import AliasAccount
from accounts.models import VirtualAccount
from orgs.utils import tmp_to_org
from .permission import AssetPermissionUtil
__all__ = ['PermAccountUtil']
class PermAccountUtil(AssetPermissionUtil):
""" 资产授权账号相关的工具 """
def validate_permission(self, user, asset, account_name):
""" 校验用户有某个资产下某个账号名的权限
:param user: User
:param asset: Asset
:param account_name: 可能是 @USER @INPUT 字符串
"""
with tmp_to_org(asset.org):
permed_accounts = self.get_permed_accounts_for_user(user, asset)
accounts_mapper = {account.alias: account for account in permed_accounts}
account = accounts_mapper.get(account_name)
return account
def get_permed_accounts_for_user(self, user, asset):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
permed_accounts = self.get_permed_accounts_from_perms(perms, user, asset)
return permed_accounts
@staticmethod
def get_permed_accounts_from_perms(perms, user, asset):
# alias: is a collection of account usernames and special accounts [@ALL, @INPUT, @USER, @ANON]
alias_action_bit_mapper = defaultdict(int)
alias_date_expired_mapper = defaultdict(list)
for perm in perms:
for alias in perm.accounts:
alias_action_bit_mapper[alias] |= perm.actions
alias_date_expired_mapper[alias].append(perm.date_expired)
asset_accounts = asset.accounts.all().active()
username_accounts_mapper = defaultdict(list)
for account in asset_accounts:
username_accounts_mapper[account.username].append(account)
cleaned_accounts_action_bit = defaultdict(int)
cleaned_accounts_expired = defaultdict(list)
# @ALL 账号先处理,后面的每个最多映射一个账号
all_action_bit = alias_action_bit_mapper.pop(AliasAccount.ALL, None)
if all_action_bit:
for account in asset_accounts:
cleaned_accounts_action_bit[account] |= all_action_bit
cleaned_accounts_expired[account].extend(
alias_date_expired_mapper[AliasAccount.ALL]
)
for alias, action_bit in alias_action_bit_mapper.items():
account = None
_accounts = []
if alias == AliasAccount.USER and user.username in username_accounts_mapper:
_accounts = username_accounts_mapper[user.username]
elif alias in username_accounts_mapper:
_accounts = username_accounts_mapper[alias]
elif alias in ['@INPUT', '@ANON', '@USER']:
account = VirtualAccount.get_special_account(alias, user, asset, from_permed=True)
elif alias.startswith('@'):
continue
if account:
_accounts += [account]
for account in _accounts:
cleaned_accounts_action_bit[account] |= action_bit
cleaned_accounts_expired[account].extend(alias_date_expired_mapper[alias])
accounts = []
for account, action_bit in cleaned_accounts_action_bit.items():
account.actions = action_bit
account.date_expired = max(cleaned_accounts_expired[account])
accounts.append(account)
return accounts

View File

@ -0,0 +1,139 @@
from collections import defaultdict
from accounts.const import AliasAccount
from accounts.models import VirtualAccount
from assets.models import Asset
from common.utils import lazyproperty
from orgs.utils import tmp_to_org, tmp_to_root_org
from .permission import AssetPermissionUtil
__all__ = ['PermAssetDetailUtil']
class PermAssetDetailUtil:
""" 资产授权账号相关的工具 """
def __init__(self, user, asset_or_id):
self.user = user
if isinstance(asset_or_id, Asset):
self.asset_id = asset_or_id.id
self.asset = asset_or_id
else:
self.asset_id = asset_or_id
@lazyproperty
def asset(self):
if self.user_asset_perms:
return self._asset
raise Asset.DoesNotExist()
@lazyproperty
def _asset(self):
from assets.models import Asset
with tmp_to_root_org():
queryset = Asset.objects.filter(id=self.asset_id)
return queryset.get()
def validate_permission(self, account_name, protocol):
with tmp_to_org(self.asset.org):
protocols = self.get_permed_protocols_for_user(only_name=True)
if 'all' not in protocols and protocol not in protocols:
return None
permed_accounts = self.get_permed_accounts_for_user()
accounts_mapper = {account.alias: account for account in permed_accounts}
account = accounts_mapper.get(account_name)
return account
@lazyproperty
def user_asset_perms(self):
perm_util = AssetPermissionUtil()
perms = perm_util.get_permissions_for_user_asset(self.user, self.asset_id)
return perms
def get_permed_accounts_for_user(self):
""" 获取授权给用户某个资产的账号 """
perms = self.user_asset_perms
permed_accounts = self.get_permed_accounts_from_perms(perms, self.user, self.asset)
return permed_accounts
def get_permed_protocols_for_user(self, only_name=False):
""" 获取授权给用户某个资产的账号 """
perms = self.user_asset_perms
names = set()
for perm in perms:
names |= set(perm.protocols)
if only_name:
return names
protocols = self.asset.protocols.all()
if 'all' not in names:
protocols = protocols.filter(name__in=names)
return protocols
@staticmethod
def parse_alias_action_date_expire(perms, asset):
alias_action_bit_mapper = defaultdict(int)
alias_date_expired_mapper = defaultdict(list)
for perm in perms:
for alias in perm.accounts:
alias_action_bit_mapper[alias] |= perm.actions
alias_date_expired_mapper[alias].append(perm.date_expired)
# @ALL 账号先处理,后面的每个最多映射一个账号
all_action_bit = alias_action_bit_mapper.pop(AliasAccount.ALL, None)
if not all_action_bit:
return alias_action_bit_mapper, alias_date_expired_mapper
asset_account_usernames = asset.accounts.all().active().values_list('username', flat=True)
for username in asset_account_usernames:
alias_action_bit_mapper[username] |= all_action_bit
alias_date_expired_mapper[username].extend(
alias_date_expired_mapper[AliasAccount.ALL]
)
return alias_action_bit_mapper, alias_date_expired_mapper
@classmethod
def map_alias_to_accounts(cls, alias_action_bit_mapper, alias_date_expired_mapper, asset, user):
username_accounts_mapper = defaultdict(list)
cleaned_accounts_expired = defaultdict(list)
asset_accounts = asset.accounts.all().active()
# 用户名 -> 账号
for account in asset_accounts:
username_accounts_mapper[account.username].append(account)
cleaned_accounts_action_bit = defaultdict(int)
for alias, action_bit in alias_action_bit_mapper.items():
account = None
_accounts = []
if alias == AliasAccount.USER and user.username in username_accounts_mapper:
_accounts = username_accounts_mapper[user.username]
elif alias in username_accounts_mapper:
_accounts = username_accounts_mapper[alias]
elif alias in ['@INPUT', '@ANON', '@USER']:
account = VirtualAccount.get_special_account(alias, user, asset, from_permed=True)
elif alias.startswith('@'):
continue
if account:
_accounts += [account]
for account in _accounts:
cleaned_accounts_action_bit[account] |= action_bit
cleaned_accounts_expired[account].extend(alias_date_expired_mapper[alias])
return cleaned_accounts_action_bit, cleaned_accounts_expired
@classmethod
def get_permed_accounts_from_perms(cls, perms, user, asset):
# alias: is a collection of account usernames and special accounts [@ALL, @INPUT, @USER, @ANON]
alias_action_bit_mapper, alias_date_expired_mapper = cls.parse_alias_action_date_expire(perms, asset)
cleaned_accounts_action_bit, cleaned_accounts_expired = cls.map_alias_to_accounts(
alias_action_bit_mapper, alias_date_expired_mapper, asset, user
)
accounts = []
for account, action_bit in cleaned_accounts_action_bit.items():
account.actions = action_bit
account.date_expired = max(cleaned_accounts_expired[account])
accounts.append(account)
return accounts

View File

@ -3,7 +3,7 @@ from django.db.models import Q
from assets.models import FavoriteAsset, Asset from assets.models import FavoriteAsset, Asset
from common.utils.common import timeit from common.utils.common import timeit
from perms.models import AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil'] __all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
@ -218,4 +218,3 @@ class UserPermNodeUtil:
nodes.extend(list(key_node_mapper.values())) nodes.extend(list(key_node_mapper.values()))
return nodes return nodes