Merge pull request #4754 from jumpserver/dev

Dev
pull/4765/head
Jiangjie.Bai 2020-10-09 10:37:47 +08:00 committed by GitHub
commit 6584890ab1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
99 changed files with 3427 additions and 1773 deletions

View File

@ -1,3 +1,4 @@
from .mixin import *
from .admin_user import *
from .asset import *
from .label import *

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404
@ -14,7 +14,7 @@ from .. import serializers
from ..tasks import (
update_asset_hardware_info_manual, test_asset_connectivity_manual
)
from ..filters import AssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend
from ..filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend
logger = get_logger(__file__)
@ -25,7 +25,7 @@ __all__ = [
]
class AssetViewSet(OrgBulkModelViewSet):
class AssetViewSet(FilterAssetByNodeMixin, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
@ -41,7 +41,7 @@ class AssetViewSet(OrgBulkModelViewSet):
'display': serializers.AssetDisplaySerializer,
}
permission_classes = (IsOrgAdminOrAppUser,)
extra_filter_backends = [AssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend]
extra_filter_backends = [FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend]
def set_assets_node(self, assets):
if not isinstance(assets, list):

88
apps/assets/api/mixin.py Normal file
View File

@ -0,0 +1,88 @@
from typing import List
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from common.utils import lazyproperty, dict_get_any, is_uuid, get_object_or_none
from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin:
permission_classes = ()
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount:
def _name(node: Node):
return '{} ({})'.format(node.value, node.assets_amount)
else:
def _name(node: Node):
return node.value
data = [
{
'id': node.key,
'name': _name(node),
'title': _name(node),
'pId': node.parent_key,
'isParent': True,
'open': node.is_org_root(),
'meta': {
'node': {
"id": node.id,
"key": node.key,
"value": node.value,
},
'type': 'node'
}
}
for node in nodes
]
return data
def get_platform(self, asset: Asset):
default = 'file'
icon = {'windows', 'linux'}
platform = asset.platform_base.lower()
if platform in icon:
return platform
return default
def serialize_assets(self, assets, node_key=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
data = [
{
'id': str(asset.id),
'name': asset.hostname,
'title': asset.ip,
'pId': get_pid(asset),
'isParent': False,
'open': False,
'iconSkin': self.get_platform(asset),
'meta': {
'type': 'asset',
'asset': {
'id': asset.id,
'hostname': asset.hostname,
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
},
}
}
for asset in assets
]
return data
class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination
@lazyproperty
def is_query_node_all_assets(self):
return is_query_node_all_assets(self.request)
@lazyproperty
def node(self):
return get_node(self.request)

View File

@ -1,16 +1,24 @@
# ~*~ coding: utf-8 ~*~
from functools import partial
from collections import namedtuple, defaultdict
from collections import namedtuple
from rest_framework import status
from rest_framework.serializers import ValidationError
from rest_framework.response import Response
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.exceptions import SomeoneIsDoingThis
from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
@ -18,12 +26,13 @@ from ..tasks import (
test_node_assets_connectivity_manual,
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
logger = get_logger(__file__)
__all__ = [
'NodeViewSet', 'NodeChildrenApi', 'NodeAssetsApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'NodeReplaceAssetsApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'MoveAssetsToNodeApi',
'NodeAddChildrenApi', 'NodeListAsTreeApi',
'NodeChildrenAsTreeApi',
'NodeTaskCreateApi',
@ -136,7 +145,7 @@ class NodeChildrenApi(generics.ListCreateAPIView):
return queryset
class NodeChildrenAsTreeApi(NodeChildrenApi):
class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"""
节点子节点作为树返回
[
@ -150,31 +159,23 @@ class NodeChildrenAsTreeApi(NodeChildrenApi):
"""
model = Node
serializer_class = TreeNodeSerializer
http_method_names = ['get']
def get_queryset(self):
queryset = super().get_queryset()
queryset = [node.as_tree_node() for node in queryset]
queryset = self.add_assets_if_need(queryset)
queryset = sorted(queryset)
return queryset
def list(self, request, *args, **kwargs):
nodes = self.get_queryset().order_by('value')
nodes = self.serialize_nodes(nodes, with_asset_amount=True)
assets = self.get_assets()
data = [*nodes, *assets]
return Response(data=data)
def add_assets_if_need(self, queryset):
def get_assets(self):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets:
return queryset
return []
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols",
)
for asset in assets:
queryset.append(asset.as_tree_node(self.instance))
return queryset
def check_need_refresh_nodes(self):
if self.request.query_params.get('refresh', '0') == '1':
Node.refresh_nodes()
return self.serialize_assets(assets, self.instance.key)
class NodeAssetsApi(generics.ListAPIView):
@ -208,6 +209,8 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@ -220,6 +223,8 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@ -228,15 +233,17 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
instance = self.get_object()
if instance != Node.org_root():
instance.assets.remove(*tuple(assets))
else:
assets = [asset for asset in assets if asset.nodes.count() > 1]
instance.assets.remove(*tuple(assets))
node = self.get_object()
node.assets.remove(*assets)
# 把孤儿资产添加到 root 节点
orphan_assets = Asset.objects.filter(id__in=[a.id for a in assets], nodes__isnull=True).distinct()
Node.org_root().assets.add(*orphan_assets)
class NodeReplaceAssetsApi(generics.UpdateAPIView):
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
permission_classes = (IsOrgAdmin,)
@ -244,9 +251,39 @@ class NodeReplaceAssetsApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
instance = self.get_object()
for asset in assets:
asset.nodes.set([instance])
node = self.get_object()
self.remove_old_nodes(assets)
node.assets.add(*assets)
def remove_old_nodes(self, assets):
m2m_model = Asset.nodes.through
# 查询资产与节点关系表,查出要移动资产与节点的所有关系
relates = m2m_model.objects.filter(asset__in=assets).values_list('asset_id', 'node_id')
if relates:
# 对关系以资产进行分组,用来发 `reverse=False` 信号
asset_nodes_mapper = defaultdict(set)
for asset_id, node_id in relates:
asset_nodes_mapper[asset_id].add(node_id)
# 组建一个资产 id -> Asset 的 mapper
asset_mapper = {asset.id: asset for asset in assets}
# 创建删除关系信号发送函数
senders = []
for asset_id, node_id_set in asset_nodes_mapper.items():
senders.append(partial(m2m_changed.send, sender=m2m_model, instance=asset_mapper[asset_id],
reverse=False, model=Node, pk_set=node_id_set))
# 发送 pre 信号
[sender(action=PRE_REMOVE) for sender in senders]
num = len(relates)
asset_ids, node_ids = zip(*relates)
# 删除之前的关系
rows, _i = m2m_model.objects.filter(asset_id__in=asset_ids, node_id__in=node_ids).delete()
if rows != num:
raise SomeoneIsDoingThis
# 发送 post 信号
[sender(action=POST_REMOVE) for sender in senders]
class NodeTaskCreateApi(generics.CreateAPIView):
@ -267,7 +304,6 @@ class NodeTaskCreateApi(generics.CreateAPIView):
@staticmethod
def refresh_nodes_cache():
Node.refresh_nodes()
Task = namedtuple('Task', ['id'])
task = Task(id="0")
return task

View File

@ -0,0 +1,6 @@
from rest_framework import status
from common.exceptions import JMSException
class NodeIsBeingUpdatedByOthers(JMSException):
status_code = status.HTTP_409_CONFLICT

View File

@ -5,8 +5,8 @@ from rest_framework.compat import coreapi, coreschema
from rest_framework import filters
from django.db.models import Q
from common.utils import dict_get_any, is_uuid, get_object_or_none
from .models import Node, Label
from .models import Label
from assets.utils import is_query_node_all_assets, get_node
class AssetByNodeFilterBackend(filters.BaseFilterBackend):
@ -21,47 +21,54 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
for field in self.fields
]
@staticmethod
def is_query_all(request):
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
query_all = query_all_arg == '1'
if show_current_asset_arg is not None:
query_all = show_current_asset_arg != '1'
return query_all
@staticmethod
def get_query_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None, False
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node, True
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(nodes__key__regex=pattern).distinct()
def filter_node_related_direct(self, queryset, node):
return queryset.filter(nodes__key=node.key).distinct()
def filter_queryset(self, request, queryset, view):
node, has_query_arg = self.get_query_node(request)
if not has_query_arg:
return queryset
node = get_node(request)
if node is None:
return queryset
query_all = self.is_query_all(request)
query_all = is_query_node_all_assets(request)
if query_all:
pattern = node.get_all_children_pattern(with_self=True)
return self.filter_node_related_all(queryset, node)
else:
# pattern = node.get_children_key_pattern(with_self=True)
# 只显示当前节点下资产
pattern = r"^{}$".format(node.key)
return self.perform_query(pattern, queryset)
return self.filter_node_related_direct(queryset, node)
class FilterAssetByNodeFilterBackend(filters.BaseFilterBackend):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
fields = ['node', 'all']
def get_schema_fields(self, view):
return [
coreapi.Field(
name=field, location='query', required=False,
type='string', example='', description='', schema=None,
)
for field in self.fields
]
def filter_queryset(self, request, queryset, view):
node = view.node
if node is None:
return queryset
query_all = view.is_query_node_all_assets
if query_all:
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
else:
return queryset.filter(nodes__key=node.key).distinct()
class LabelFilterBackend(filters.BaseFilterBackend):
@ -113,9 +120,14 @@ class LabelFilterBackend(filters.BaseFilterBackend):
class AssetRelatedByNodeFilterBackend(AssetByNodeFilterBackend):
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(asset__nodes__key__regex=pattern).distinct()
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(asset__nodes__key__istartswith=f'{node.key}:') |
Q(asset__nodes__key=node.key)
).distinct()
def filter_node_related_direct(self, queryset, node):
return queryset.filter(asset__nodes__key=node.key).distinct()
class IpInFilterBackend(filters.BaseFilterBackend):

View File

@ -0,0 +1,23 @@
# Generated by Django 2.2.13 on 2020-09-04 09:51
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0055_auto_20200811_1845'),
]
operations = [
migrations.AddField(
model_name='node',
name='assets_amount',
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name='node',
name='parent_key',
field=models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key'),
),
]

View File

@ -0,0 +1,38 @@
# Generated by Django 2.2.13 on 2020-08-21 08:20
from django.db import migrations
from django.db.models import Q
def fill_node_value(apps, schema_editor):
Node = apps.get_model('assets', 'Node')
Asset = apps.get_model('assets', 'Asset')
node_queryset = Node.objects.all()
node_amount = node_queryset.count()
width = len(str(node_amount))
print('\n')
for i, node in enumerate(node_queryset):
print(f'\t{i+1:0>{width}}/{node_amount} compute node[{node.key}]`s assets_amount ...')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
key = node.key
try:
parent_key = key[:key.rindex(':')]
except ValueError:
parent_key = ''
node.assets_amount = assets_amount
node.parent_key = parent_key
node.save()
print(' ' + '.'*65, end='')
class Migration(migrations.Migration):
dependencies = [
('assets', '0056_auto_20200904_1751'),
]
operations = [
migrations.RunPython(fill_node_value)
]

View File

@ -47,6 +47,10 @@ class AssetManager(OrgManager):
)
class AssetOrgManager(OrgManager):
pass
class AssetQuerySet(models.QuerySet):
def active(self):
return self.filter(is_active=True)
@ -226,6 +230,7 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None
def __str__(self):

View File

@ -18,3 +18,11 @@ class FavoriteAsset(CommonModelMixin):
@classmethod
def get_user_favorite_assets_id(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user):
from assets.models import Asset
from perms.utils.user_asset_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(user).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@ -2,132 +2,35 @@
#
import uuid
import re
import time
from django.db import models, transaction
from django.db.models import Q
from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext
from django.core.cache import cache
from django.db.transaction import atomic
from common.utils import get_logger, lazyproperty
from common.utils import get_logger
from common.utils.common import lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org, current_org
from orgs.models import Organization
__all__ = ['Node']
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key']
logger = get_logger(__name__)
def compute_parent_key(key):
try:
return key[:key.rindex(':')]
except ValueError:
return ''
class NodeQuerySet(models.QuerySet):
def delete(self):
raise PermissionError("Bulk delete node deny")
class TreeCache:
updated_time_cache_key = 'NODE_TREE_UPDATED_AT_{}'
cache_time = 3600
assets_updated_time_cache_key = 'NODE_TREE_ASSETS_UPDATED_AT_{}'
def __init__(self, tree, org_id):
now = time.time()
self.created_time = now
self.assets_created_time = now
self.tree = tree
self.org_id = org_id
def _has_changed(self, tp="tree"):
if tp == "assets":
key = self.assets_updated_time_cache_key.format(self.org_id)
else:
key = self.updated_time_cache_key.format(self.org_id)
updated_time = cache.get(key, 0)
if updated_time > self.created_time:
return True
else:
return False
@classmethod
def set_changed(cls, tp="tree", t=None, org_id=None):
if org_id is None:
org_id = current_org.id
if tp == "assets":
key = cls.assets_updated_time_cache_key.format(org_id)
else:
key = cls.updated_time_cache_key.format(org_id)
ttl = cls.cache_time
if not t:
t = time.time()
cache.set(key, t, ttl)
def tree_has_changed(self):
return self._has_changed("tree")
def set_tree_changed(self, t=None):
logger.debug("Set tree tree changed")
self.__class__.set_changed(t=t, tp="tree")
def assets_has_changed(self):
return self._has_changed("assets")
def set_tree_assets_changed(self, t=None):
logger.debug("Set tree assets changed")
self.__class__.set_changed(t=t, tp="assets")
def get(self):
if self.tree_has_changed():
self.renew()
return self.tree
if self.assets_has_changed():
self.tree.init_assets()
return self.tree
def renew(self):
new_obj = self.__class__.new(self.org_id)
self.tree = new_obj.tree
self.created_time = new_obj.created_time
self.assets_created_time = new_obj.assets_created_time
@classmethod
def new(cls, org_id=None):
from ..utils import TreeService
logger.debug("Create node tree")
if not org_id:
org_id = current_org.id
with tmp_to_org(org_id):
tree = TreeService.new()
obj = cls(tree, org_id)
obj.tree = tree
return obj
class TreeMixin:
_org_tree_map = {}
@classmethod
def tree(cls):
org_id = current_org.org_id()
t = cls.get_local_tree_cache(org_id)
if t is None:
t = TreeCache.new()
cls._org_tree_map[org_id] = t
return t.get()
@classmethod
def get_local_tree_cache(cls, org_id=None):
t = cls._org_tree_map.get(org_id)
return t
@classmethod
def refresh_tree(cls, t=None):
TreeCache.set_changed(tp="tree", t=t, org_id=current_org.id)
@classmethod
def refresh_node_assets(cls, t=None):
TreeCache.set_changed(tp="assets", t=t, org_id=current_org.id)
raise NotImplementedError
class FamilyMixin:
@ -138,16 +41,16 @@ class FamilyMixin:
@staticmethod
def clean_children_keys(nodes_keys):
nodes_keys = sorted(list(nodes_keys), key=lambda x: (len(x), x))
sort_key = lambda k: [int(i) for i in k.split(':')]
nodes_keys = sorted(list(nodes_keys), key=sort_key)
nodes_keys_clean = []
for key in nodes_keys[::-1]:
found = False
for k in nodes_keys:
if key.startswith(k + ':'):
found = True
break
if not found:
nodes_keys_clean.append(key)
base_key = ''
for key in nodes_keys:
if key.startswith(base_key + ':'):
continue
nodes_keys_clean.append(key)
base_key = key
return nodes_keys_clean
@classmethod
@ -175,13 +78,16 @@ class FamilyMixin:
return re.match(children_pattern, self.key)
def get_children(self, with_self=False):
pattern = self.get_children_key_pattern(with_self=with_self)
return Node.objects.filter(key__regex=pattern)
q = Q(parent_key=self.key)
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
def get_all_children(self, with_self=False):
pattern = self.get_all_children_pattern(with_self=with_self)
children = Node.objects.filter(key__regex=pattern)
return children
q = Q(key__istartswith=f'{self.key}:')
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
@property
def children(self):
@ -192,10 +98,10 @@ class FamilyMixin:
return self.get_all_children(with_self=False)
def create_child(self, value, _id=None):
with transaction.atomic():
with atomic(savepoint=False):
child_key = self.get_next_child_key()
child = self.__class__.objects.create(
id=_id, key=child_key, value=value
id=_id, key=child_key, value=value, parent_key=self.key,
)
return child
@ -255,10 +161,13 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self):
return compute_parent_key(self.key)
def is_parent(self, other):
return other.is_children(self)
@ -300,39 +209,33 @@ class FamilyMixin:
return [*tuple(ancestors), self, *tuple(children)]
class FullValueMixin:
key = ''
@lazyproperty
def full_value(self):
if self.is_org_root():
return self.value
value = self.tree().get_node_full_tag(self.key)
return value
class NodeAssetsMixin:
key = ''
id = None
@lazyproperty
def assets_amount(self):
amount = self.tree().assets_amount(self.key)
return amount
def get_all_assets(self):
from .asset import Asset
if self.is_org_root():
return Asset.objects.filter(org_id=self.org_id)
pattern = '^{0}$|^{0}:'.format(self.key)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct()
@classmethod
def get_node_all_assets_by_key_v2(cls, key):
# 最初的写法是:
# Asset.objects.filter(Q(nodes__key__startswith=f'{node.key}:') | Q(nodes__id=node.id))
# 可是 startswith 会导致表关联时 Asset 索引失效
from .asset import Asset
node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') |
Q(key=key)
).values_list('id', flat=True).distinct()
assets = Asset.objects.filter(
nodes__id__in=list(node_ids)
).distinct()
return assets
def get_assets(self):
from .asset import Asset
if self.is_org_root():
assets = Asset.objects.filter(Q(nodes=self) | Q(nodes__isnull=True))
else:
assets = Asset.objects.filter(nodes=self)
assets = Asset.objects.filter(nodes=self)
return assets.distinct()
def get_valid_assets(self):
@ -341,51 +244,54 @@ class NodeAssetsMixin:
def get_all_valid_assets(self):
return self.get_all_assets().valid()
@classmethod
def _get_nodes_all_assets(cls, nodes_keys):
"""
当节点比较多的时候这种正则方式性能差极了
:param nodes_keys:
:return:
"""
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
nodes_children_pattern = set()
for key in nodes_keys:
children_pattern = cls.get_node_all_children_key_pattern(key)
nodes_children_pattern.add(children_pattern)
pattern = '|'.join(nodes_children_pattern)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
@classmethod
def get_nodes_all_assets_ids(cls, nodes_keys):
nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = set()
for key in nodes_keys:
node_assets_ids = cls.tree().all_assets(key)
assets_ids.update(set(node_assets_ids))
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True)
return assets_ids
@classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None):
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = cls.get_nodes_all_assets_ids(nodes_keys)
q = Q()
node_ids = ()
for key in nodes_keys:
q |= Q(key__startswith=f'{key}:')
q |= Q(key=key)
if q:
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True)
q = Q(nodes__id__in=list(node_ids))
if extra_assets_ids:
assets_ids.update(set(extra_assets_ids))
return Asset.objects.filter(id__in=assets_ids)
q |= Q(id__in=extra_assets_ids)
if q:
return Asset.org_objects.filter(q).distinct()
else:
return Asset.objects.none()
class SomeNodesMixin:
key = ''
default_key = '1'
default_value = 'Default'
ungrouped_key = '-10'
ungrouped_value = _('ungrouped')
empty_key = '-11'
empty_value = _("empty")
favorite_key = '-12'
favorite_value = _("favorite")
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
def is_default_node(self):
return self.key == self.default_key
@ -420,51 +326,15 @@ class SomeNodesMixin:
@classmethod
def org_root(cls):
root = cls.objects.filter(key__regex=r'^[0-9]+$')
root = cls.objects.filter(parent_key='').exclude(key__startswith='-')
if root:
return root[0]
else:
return cls.create_org_root_node()
@classmethod
def ungrouped_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.ungrouped_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.ungrouped_key
)
return obj
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
@classmethod
def favorite_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.favorite_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.favorite_key
)
return obj
@classmethod
def initial_some_nodes(cls):
cls.default_node()
cls.ungrouped_node()
cls.favorite_node()
@classmethod
def modify_other_org_root_node_key(cls):
@ -496,12 +366,15 @@ class SomeNodesMixin:
logger.info('Modify key ( {} > {} )'.format(old_key, new_key))
class Node(OrgModelMixin, SomeNodesMixin, TreeMixin, FamilyMixin, FullValueMixin, NodeAssetsMixin):
class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True
@ -536,18 +409,20 @@ class Node(OrgModelMixin, SomeNodesMixin, TreeMixin, FamilyMixin, FullValueMixin
def name(self):
return self.value
@lazyproperty
def full_value(self):
# 不要在列表中调用该属性
values = self.__class__.objects.filter(
key__in=self.get_ancestor_keys()
).values_list('key', 'value')
values = [v for k, v in sorted(values, key=lambda x: len(x[0]))]
values.append(self.value)
return ' / '.join(values)
@property
def level(self):
return len(self.key.split(':'))
@classmethod
def refresh_nodes(cls):
cls.refresh_tree()
@classmethod
def refresh_assets(cls):
cls.refresh_node_assets()
def as_tree_node(self):
from common.tree import TreeNode
name = '{} ({})'.format(self.value, self.assets_amount)

39
apps/assets/pagination.py Normal file
View File

@ -0,0 +1,39 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from assets.models import Node
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'node', 'all', 'show_current_asset',
'node_id', 'display', 'draw',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
return super().get_count(queryset)
is_query_all = self._view.is_query_node_all_assets
if is_query_all:
node = self._view.node
if not node:
node = Node.org_root()
return node.assets_amount
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)

View File

@ -75,7 +75,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
"""
class Meta:
model = Asset
list_serializer_class = AdaptedBulkListSerializer
fields_mini = ['id', 'hostname', 'ip']
fields_small = fields_mini + [
'protocol', 'port', 'protocols', 'is_active', 'public_ip',
@ -116,6 +115,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels')
return queryset
def compatible_with_old_protocol(self, validated_data):
@ -153,7 +153,7 @@ class AssetDisplaySerializer(AssetSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = super().setup_eager_loading(queryset)
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset

View File

@ -229,15 +229,8 @@ class SystemUserNodeRelationSerializer(RelationMixin, serializers.ModelSerialize
'id', 'node', "node_display",
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tree = Node.tree()
def get_node_display(self, obj):
if hasattr(obj, 'node_key'):
return self.tree.get_node_full_tag(obj.node_key)
else:
return obj.node.full_value
return obj.node.full_value
class SystemUserUserRelationSerializer(RelationMixin, serializers.ModelSerializer):

View File

@ -1,17 +1,21 @@
# -*- coding: utf-8 -*-
#
from collections import defaultdict
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import (
post_save, m2m_changed, post_delete
post_save, m2m_changed, pre_delete, post_delete
)
from django.db.models.aggregates import Count
from django.db.models import Q, F
from django.dispatch import receiver
from common.local import thread_local
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from common.decorator import on_transaction_commit
from orgs.utils import tmp_to_root_org
from .models import Asset, SystemUser, Node, AuthBook
from .utils import TreeService
from .models import Asset, SystemUser, Node, compute_parent_key
from users.models import User
from .tasks import (
update_assets_hardware_info_util,
test_asset_connectivity_util,
@ -54,15 +58,6 @@ def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
instance.nodes.add(Node.org_root())
@receiver(post_delete, sender=Asset)
def on_asset_delete(sender, instance=None, **kwargs):
"""
当资产删除时刷新节点节点中存在节点和资产的关系
"""
logger.debug("Asset delete signal recv: {}".format(instance))
Node.refresh_assets()
@receiver(post_save, sender=SystemUser, dispatch_uid="jms")
def on_system_user_update(sender, instance=None, created=True, **kwargs):
"""
@ -82,7 +77,7 @@ def on_system_user_assets_change(sender, instance=None, action='', model=None, p
"""
当系统用户和资产关系发生变化时应该重新推送系统用户到新添加的资产中
"""
if action != "post_add":
if action != POST_ADD:
return
logger.debug("System user assets change signal recv: {}".format(instance))
queryset = model.objects.filter(pk__in=pk_set)
@ -101,7 +96,7 @@ def on_system_user_users_change(sender, instance=None, action='', model=None, pk
"""
当系统用户和用户关系发生变化时应该重新推送系统用户资产中
"""
if action != "post_add":
if action != POST_ADD:
return
if not instance.username_same_with_user:
return
@ -120,7 +115,7 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
"""
当系统用户和节点关系发生变化时应该将节点下资产关联到新的系统用户上
"""
if action != "post_add":
if action != POST_ADD:
return
logger.info("System user nodes update signal recv: {}".format(instance))
@ -135,104 +130,217 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
@receiver(m2m_changed, sender=SystemUser.groups.through)
def on_system_user_groups_change(sender, instance=None, action=None, model=None,
pk_set=None, reverse=False, **kwargs):
def on_system_user_groups_change(instance, action, pk_set, reverse, **kwargs):
"""
当系统用户和用户组关系发生变化时应该将组下用户关联到新的系统用户上
"""
if action != "post_add" or reverse:
if action != POST_ADD:
return
if reverse:
raise M2MReverseNotAllowed
logger.info("System user groups update signal recv: {}".format(instance))
groups = model.objects.filter(pk__in=pk_set).annotate(users_count=Count("users"))
users = groups.filter(users_count__gt=0).values_list('users', flat=True)
instance.users.add(*tuple(users))
users = User.objects.filter(groups__id__in=pk_set).distinct()
instance.users.add(users)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_change(sender, instance=None, action='', **kwargs):
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
"""
资产节点发生变化时刷新节点
"""
if action.startswith('post'):
logger.debug("Asset nodes change signal recv: {}".format(instance))
Node.refresh_assets()
with tmp_to_root_org():
Node.refresh_assets()
本操作共访问 4 次数据库
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
"""
当资产的节点发生变化时或者 当节点的资产关系发生变化时
节点下新增的资产添加到节点关联的系统用户中
"""
if action != "post_add":
if action != POST_ADD:
return
logger.debug("Assets node add signal recv: {}".format(action))
if model == Node:
nodes = model.objects.filter(pk__in=pk_set).values_list('key', flat=True)
assets = [instance.id]
else:
if reverse:
nodes = [instance.key]
assets = model.objects.filter(pk__in=pk_set).values_list('id', flat=True)
asset_ids = pk_set
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
# 节点资产发生变化时,将资产关联到节点及祖先节点关联的系统用户, 只关注新增的
nodes_ancestors_keys = set()
node_tree = TreeService.new()
for node in nodes:
ancestors_keys = node_tree.ancestors_ids(nid=node)
nodes_ancestors_keys.update(ancestors_keys)
system_users = SystemUser.objects.filter(nodes__key__in=nodes_ancestors_keys)
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
system_users_assets = defaultdict(set)
for system_user in system_users:
assets_has_set = system_user.assets.all().filter(id__in=assets).values_list('id', flat=True)
assets_remain = set(assets) - set(assets_has_set)
system_users_assets[system_user].update(assets_remain)
for system_user, _assets in system_users_assets.items():
system_user.assets.add(*tuple(_assets))
# 查询所有祖先节点关联的系统用户,都是要跟资产建立关系的
system_user_ids = SystemUser.objects.filter(
nodes__key__in=nodes_ancestors_keys
).distinct().values_list('id', flat=True)
# 查询所有已存在的关系
m2m_model = SystemUser.assets.through
exist = set(m2m_model.objects.filter(
systemuser_id__in=system_user_ids, asset_id__in=asset_ids
).values_list('systemuser_id', 'asset_id'))
# TODO 优化
to_create = []
for system_user_id in system_user_ids:
asset_ids_to_push = []
for asset_id in asset_ids:
if (system_user_id, asset_id) in exist:
continue
asset_ids_to_push.append(asset_id)
to_create.append(m2m_model(
systemuser_id=system_user_id,
asset_id=asset_id
))
push_system_user_to_assets.delay(system_user_id, asset_ids_to_push)
m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
def _remove_ancestor_keys(ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_remove(sender, instance=None, action='', model=None,
pk_set=None, **kwargs):
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
"""
监控资产删除节点关系, 或节点删除资产避免产生游离资产
"""
if action not in ["post_remove", "pre_clear", "post_clear"]:
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
if action == "pre_clear":
if model == Node:
instance._nodes = list(instance.nodes.all())
else:
instance._assets = list(instance.assets.all())
return
logger.debug("Assets node remove signal recv: {}".format(action))
if action == "post_remove":
queryset = model.objects.filter(pk__in=pk_set)
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
else:
if model == Node:
queryset = instance._nodes
else:
queryset = instance._assets
if model == Node:
assets = [instance]
else:
assets = queryset
if isinstance(assets, list):
assets_not_has_node = []
for asset in assets:
if asset.nodes.all().count() == 0:
assets_not_has_node.append(asset.id)
else:
assets_not_has_node = assets.annotate(nodes_count=Count('nodes'))\
.filter(nodes_count=0).values_list('id', flat=True)
Node.org_root().assets.add(*tuple(assets_not_has_node))
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
@receiver([post_save, post_delete], sender=Node)
def on_node_update_or_created(sender, **kwargs):
# 刷新节点
Node.refresh_nodes()
with tmp_to_root_org():
Node.refresh_nodes()
ASSET_DELETE_SIGNAL_FOR_NODE_TREE_PARAMS = 'asset_delete_signal_for_node_tree_params'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, **kwargs):
node_keys = Node.objects.filter(
assets=instance
).distinct().values_list('key', flat=True)
params = {
'node_keys': set(node_keys),
'asset_pk': instance.id,
'operator': sub
}
setattr(thread_local, ASSET_DELETE_SIGNAL_FOR_NODE_TREE_PARAMS, params)
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, **kwargs):
params = getattr(thread_local, ASSET_DELETE_SIGNAL_FOR_NODE_TREE_PARAMS, None)
if params and params.get('asset_pk') == instance.id:
_update_nodes_asset_amount(**params)

View File

@ -9,3 +9,4 @@ from .gather_asset_users import *
from .gather_asset_hardware_info import *
from .push_system_user import *
from .system_user_connectivity import *
from .nodes_amount import *

View File

@ -0,0 +1,14 @@
from celery import shared_task
from assets.utils import check_node_assets_amount
from common.utils import get_logger
from common.utils.timezone import now
logger = get_logger(__file__)
@shared_task()
def check_node_assets_amount_celery_task():
logger.info(f'>>> {now()} begin check_node_assets_amount_celery_task ...')
check_node_assets_amount()
logger.info(f'>>> {now()} end check_node_assets_amount_celery_task ...')

View File

@ -2,10 +2,12 @@
from itertools import groupby
from celery import shared_task
from common.db.utils import get_object_if_need, get_objects_if_need
from django.utils.translation import ugettext as _
from django.db.models import Empty
from common.utils import encrypt_password, get_logger
from assets.models import SystemUser, Asset
from orgs.utils import org_aware_func
from . import const
from .utils import clean_ansible_task_hosts, group_asset_by_platform
@ -221,6 +223,7 @@ def push_system_user_util(system_user, assets, task_name, username=None):
@shared_task(queue="ansible")
def push_system_user_to_assets_manual(system_user, username=None):
system_user = get_object_if_need(SystemUser, system_user)
assets = system_user.get_related_assets()
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name=task_name, username=username)
@ -239,10 +242,10 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
def push_system_user_to_assets(system_user, assets, username=None):
task_name = _("Push system users to assets: {}").format(system_user.name)
system_user = get_object_if_need(SystemUser, system_user)
assets = get_objects_if_need(Asset, assets)
return push_system_user_util(system_user, assets, task_name, username=username)
# @shared_task
# @register_as_period_task(interval=3600)
# @after_app_ready_start

View File

@ -2,6 +2,7 @@
from django.urls import path, re_path
from rest_framework_nested import routers
from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi
@ -54,9 +55,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@ -1,195 +1,52 @@
# ~*~ coding: utf-8 ~*~
#
from treelib import Tree
from treelib.exceptions import NodeIDAbsentError
from collections import defaultdict
from copy import deepcopy
from django.db.models import Q
from common.utils import get_logger, timeit, lazyproperty
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.http import is_true
from .models import Asset, Node
logger = get_logger(__file__)
class TreeService(Tree):
tag_sep = ' / '
def check_node_assets_amount():
for node in Node.objects.all():
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
@staticmethod
@timeit
def get_nodes_assets_map():
nodes_assets_map = defaultdict(set)
asset_node_list = Node.assets.through.objects.values_list(
'asset', 'node__key'
)
for asset_id, key in asset_node_list:
nodes_assets_map[key].add(asset_id)
return nodes_assets_map
if node.assets_amount != assets_amount:
print(f'>>> <Node:{node.key}> wrong assets amount '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
@classmethod
@timeit
def new(cls):
from .models import Node
all_nodes = list(Node.objects.all().values("key", "value"))
all_nodes.sort(key=lambda x: len(x["key"].split(":")))
tree = cls()
tree.create_node(tag='', identifier='', data={})
for node in all_nodes:
key = node["key"]
value = node["value"]
parent_key = ":".join(key.split(":")[:-1])
tree.safe_create_node(
tag=value, identifier=key,
parent=parent_key,
)
tree.init_assets()
return tree
def init_assets(self):
node_assets_map = self.get_nodes_assets_map()
for node in self.all_nodes_itr():
key = node.identifier
assets = node_assets_map.get(key, set())
data = {"assets": assets, "all_assets": None}
node.data = data
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
def safe_create_node(self, **kwargs):
parent = kwargs.get("parent")
if not self.contains(parent):
kwargs['parent'] = self.root
self.create_node(**kwargs)
def all_children_ids(self, nid, with_self=True):
children_ids = self.expand_tree(nid)
if not with_self:
next(children_ids)
return list(children_ids)
def is_query_node_all_assets(request):
request = request
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
if show_current_asset_arg is not None:
return not is_true(show_current_asset_arg)
return is_true(query_all_arg)
def all_children(self, nid, with_self=True, deep=False):
children_ids = self.all_children_ids(nid, with_self=with_self)
return [self.get_node(i, deep=deep) for i in children_ids]
def ancestors_ids(self, nid, with_self=True):
ancestor_ids = list(self.rsearch(nid))
ancestor_ids.pop()
if not with_self:
ancestor_ids.pop(0)
return ancestor_ids
def get_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None
def ancestors(self, nid, with_self=False, deep=False):
ancestor_ids = self.ancestors_ids(nid, with_self=with_self)
ancestors = [self.get_node(i, deep=deep) for i in ancestor_ids]
return ancestors
def get_node_full_tag(self, nid):
ancestors = self.ancestors(nid, with_self=True)
ancestors.reverse()
return self.tag_sep.join([n.tag for n in ancestors])
def get_family(self, nid, deep=False):
ancestors = self.ancestors(nid, with_self=False, deep=deep)
children = self.all_children(nid, with_self=False)
return ancestors + [self[nid]] + children
@staticmethod
def is_parent(child, parent):
parent_id = child.bpointer
return parent_id == parent.identifier
def root_node(self):
return self.get_node(self.root)
def get_node(self, nid, deep=False):
node = super().get_node(nid)
if deep:
node = self.copy_node(node)
node.data = {}
return node
def parent(self, nid, deep=False):
parent = super().parent(nid)
if deep:
parent = self.copy_node(parent)
return parent
@lazyproperty
def invalid_assets(self):
assets = Asset.objects.filter(is_active=False).values_list('id', flat=True)
return assets
def set_assets(self, nid, assets):
node = self.get_node(nid)
if node.data is None:
node.data = {}
node.data["assets"] = assets
def assets(self, nid):
node = self.get_node(nid)
return node.data.get("assets", set())
def valid_assets(self, nid):
return set(self.assets(nid)) - set(self.invalid_assets)
def all_assets(self, nid):
node = self.get_node(nid)
if node.data is None:
node.data = {}
all_assets = node.data.get("all_assets")
if all_assets is not None:
return all_assets
all_assets = set(self.assets(nid))
try:
children = self.children(nid)
except NodeIDAbsentError:
children = []
for child in children:
all_assets.update(self.all_assets(child.identifier))
node.data["all_assets"] = all_assets
return all_assets
def all_valid_assets(self, nid):
return set(self.all_assets(nid)) - set(self.invalid_assets)
def assets_amount(self, nid):
return len(self.all_assets(nid))
def valid_assets_amount(self, nid):
return len(self.all_valid_assets(nid))
@staticmethod
def copy_node(node):
new_node = deepcopy(node)
new_node.fpointer = None
return new_node
def safe_add_ancestors(self, node, ancestors):
# 如果没有祖先节点,那么添加该节点, 父节点是root node
if len(ancestors) == 0:
parent = self.root_node()
else:
parent = ancestors[0]
# 如果当前节点已再树中,则移动当前节点到父节点中
# 这个是由于 当前节点放到了二级节点中
if not self.contains(parent.identifier):
# logger.debug('Add parent: {}'.format(parent.identifier))
self.safe_add_ancestors(parent, ancestors[1:])
if self.contains(node.identifier):
# msg = 'Move node to parent: {} => {}'.format(
# node.identifier, parent.identifier
# )
# logger.debug(msg)
self.move_node(node.identifier, parent.identifier)
else:
# logger.debug('Add node: {}'.format(node.identifier))
self.add_node(node, parent)
#
# def __getstate__(self):
# self.mutex = None
# self.all_nodes_assets_map = {}
# self.nodes_assets_map = {}
# return self.__dict__
# def __setstate__(self, state):
# self.__dict__ = state
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node

View File

@ -73,12 +73,12 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet):
token.save()
except (ValueError, SSOToken.DoesNotExist):
self.send_auth_signal(success=False, reason='authkey_invalid')
return HttpResponseRedirect(reverse('authentication:login'))
return HttpResponseRedirect(next_url)
# 判断是否过期
if (utcnow().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL:
self.send_auth_signal(success=False, reason='authkey_timeout')
return HttpResponseRedirect(reverse('authentication:login'))
return HttpResponseRedirect(next_url)
user = token.user
login(self.request, user, 'authentication.backends.api.SSOAuthentication')

View File

@ -0,0 +1,10 @@
from django_cas_ng.middleware import CASMiddleware as _CASMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class CASMiddleware(_CASMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_CAS:
raise MiddlewareNotUsed

View File

@ -0,0 +1,10 @@
from jms_oidc_rp.middleware import OIDCRefreshIDTokenMiddleware as _OIDCRefreshIDTokenMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class OIDCRefreshIDTokenMiddleware(_OIDCRefreshIDTokenMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_OPENID:
raise MiddlewareNotUsed

View File

@ -57,7 +57,7 @@
<div class="text-muted text-center">
<div>
<a href="{% url 'authentication:forgot-password' %}">
<a id="forgot_password" href="#">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
@ -90,5 +90,17 @@
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#form').submit();//post提交
}
var authDB = '{{ AUTH_DB }}';
var forgotPasswordUrl = "{% url 'authentication:forgot-password' %}";
$(document).ready(function () {
}).on('click', '#forgot_password', function () {
if (authDB === 'True'){
window.open(forgotPasswordUrl, "_blank")
}
else{
alert("{% trans 'You are using another authentication server, please contact your administrator' %}")
}
})
</script>
{% endblock %}

View File

@ -131,7 +131,7 @@
<button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button>
</div>
<div style="text-align: center">
<a href="{% url 'authentication:forgot-password' %}">
<a id="forgot_password" href="#">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
@ -144,6 +144,7 @@
</div>
</div>
</div>
</div>
</body>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
@ -161,6 +162,18 @@
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#contact-form').submit();//post提交
}
var authDB = '{{ AUTH_DB }}';
var forgotPasswordUrl = "{% url 'authentication:forgot-password' %}";
$(document).ready(function () {
}).on('click', '#forgot_password', function () {
if (authDB === 'True'){
window.open(forgotPasswordUrl, "_blank")
}
else{
alert("{% trans 'You are using another authentication server, please contact your administrator' %}")
}
})
</script>
</html>

View File

@ -131,7 +131,8 @@ class UserLoginView(mixins.AuthMixin, FormView):
context = {
'demo_mode': os.environ.get("DEMO_MODE"),
'AUTH_OPENID': settings.AUTH_OPENID,
'rsa_public_key': rsa_public_key
'rsa_public_key': rsa_public_key,
'AUTH_DB': settings.AUTH_DB
}
kwargs.update(context)
return super().get_context_data(**kwargs)

View File

@ -0,0 +1,2 @@
UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree'
UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task'

View File

@ -0,0 +1,14 @@
"""
`m2m_changed`
```
def m2m_signals_handler(action, instance, reverse, model, pk_set, using):
pass
```
"""
PRE_ADD = 'pre_add'
POST_ADD = 'post_add'
PRE_REMOVE = 'pre_remove'
POST_REMOVE = 'post_remove'
PRE_CLEAR = 'pre_clear'
POST_CLEAR = 'post_clear'

27
apps/common/db/utils.py Normal file
View File

@ -0,0 +1,27 @@
from common.utils import get_logger
logger = get_logger(__file__)
def get_object_if_need(model, pk):
if not isinstance(pk, model):
try:
return model.objects.get(id=pk)
except model.DoesNotExist as e:
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist')
raise e
return pk
def get_objects_if_need(model, pks):
if not pks:
return pks
if not isinstance(pks[0], model):
objs = list(model.objects.filter(id__in=pks))
if len(objs) != len(pks):
pks = set(pks)
exists_pks = {o.id for o in objs}
not_found_pks = ','.join(pks - exists_pks)
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
return objs
return pks

View File

@ -1,12 +1,14 @@
from django.core.exceptions import PermissionDenied, ObjectDoesNotExist as DJObjectDoesNotExist
from django.http import Http404
from django.utils.translation import gettext
from rest_framework import exceptions
from rest_framework.views import set_rollback
from rest_framework.response import Response
from common.exceptions import JMSObjectDoesNotExist
from common.utils import get_logger
logger = get_logger(__name__)
def extract_object_name(exc, index=0):
@ -20,6 +22,8 @@ def extract_object_name(exc, index=0):
def common_exception_handler(exc, context):
logger.exception('')
if isinstance(exc, Http404):
exc = JMSObjectDoesNotExist(object_name=extract_object_name(exc, 1))
elif isinstance(exc, PermissionDenied):

View File

@ -18,3 +18,18 @@ class JMSObjectDoesNotExist(APIException):
if detail is None and object_name:
detail = self.default_detail % object_name
super(JMSObjectDoesNotExist, self).__init__(detail=detail, code=code)
class SomeoneIsDoingThis(JMSException):
status_code = status.HTTP_409_CONFLICT
default_detail = _('Someone else is doing this. Please wait for complete')
class Timeout(JMSException):
status_code = status.HTTP_408_REQUEST_TIMEOUT
default_detail = _('Your request timeout')
class M2MReverseNotAllowed(JMSException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('M2M reverse not allowed')

View File

@ -3,6 +3,8 @@
from django.http import HttpResponse
from django.utils.encoding import iri_to_uri
from rest_framework.serializers import BooleanField
class HttpResponseTemporaryRedirect(HttpResponse):
status_code = 307
@ -14,3 +16,7 @@ class HttpResponseTemporaryRedirect(HttpResponse):
def get_remote_addr(request):
return request.META.get("HTTP_X_FORWARDED_HOST") or request.META.get("REMOTE_ADDR")
def is_true(value):
return value in BooleanField.TRUE_VALUES

View File

@ -11,8 +11,6 @@ from django.core.cache import cache
from django.http import JsonResponse
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework import status
from rest_framework_bulk.drf3.mixins import BulkDestroyModelMixin
from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter
from ..utils import lazyproperty
@ -237,6 +235,7 @@ class RelationMixin:
for i in instances:
to_id = getattr(i, self.to_field).id
# TODO 优化,不应该每次都查询数据库
from_obj = getattr(i, self.from_field)
from_to_mapper[from_obj].append(to_id)

View File

@ -0,0 +1,17 @@
from concurrent.futures import ThreadPoolExecutor
class SingletonThreadPoolExecutor(ThreadPoolExecutor):
"""
该类不要直接实例化
"""
def __new__(cls, max_workers=None, thread_name_prefix=None):
if cls is SingletonThreadPoolExecutor:
raise NotImplementedError
if getattr(cls, '_object', None) is None:
cls._object = ThreadPoolExecutor(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix
)
return cls._object

View File

@ -164,6 +164,11 @@ def get_request_ip_or_data(request):
return ip
def get_request_user_agent(request):
user_agent = request.META.get('HTTP_USER_AGENT', '')
return user_agent
def validate_ip(ip):
try:
ipaddress.ip_address(ip)

View File

@ -18,6 +18,7 @@ from importlib import import_module
from django.urls import reverse_lazy
from django.contrib.staticfiles.templatetags.staticfiles import static
from urllib.parse import urljoin, urlparse
from django.utils.translation import ugettext_lazy as _
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_DIR = os.path.dirname(BASE_DIR)
@ -256,6 +257,7 @@ class Config(dict):
'SYSLOG_FACILITY': 'user',
'SYSLOG_SOCKTYPE': 2,
'PERM_SINGLE_ASSET_TO_UNGROUP_NODE': False,
'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60,
'WINDOWS_SSH_DEFAULT_SHELL': 'cmd',
'FLOWER_URL': "127.0.0.1:5555",
'DEFAULT_ORG_SHOW_ALL_USERS': True,
@ -268,7 +270,11 @@ class Config(dict):
'TIME_ZONE': 'Asia/Shanghai',
'CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED': True,
'USER_LOGIN_SINGLE_MACHINE_ENABLED': False,
'TICKETS_ENABLED': True
'TICKETS_ENABLED': True,
'SESSION_COOKIE_SECURE': False,
'CSRF_COOKIE_SECURE': False,
'REFERER_CHECK_ENABLED': False,
'SERVER_REPLAY_STORAGE': {}
}
def compatible_auth_openid_of_key(self):
@ -449,6 +455,9 @@ class DynamicConfig:
backends.insert(0, 'authentication.backends.api.SSOAuthentication')
return backends
def AUTH_DB(self):
return len(self.AUTHENTICATION_BACKENDS()) == 2
def XPACK_LICENSE_IS_VALID(self):
if not HAS_XPACK:
return False
@ -458,6 +467,16 @@ class DynamicConfig:
except:
return False
def XPACK_INTERFACE_LOGIN_TITLE(self):
default_title = _('Welcome to the JumpServer open source fortress')
if not HAS_XPACK:
return default_title
try:
from xpack.plugins.interface.models import Interface
return Interface.get_login_title()
except:
return default_title
def LOGO_URLS(self):
logo_urls = {'logo_logout': static('img/logo.png'),
'logo_index': static('img/logo_text.png'),

View File

@ -6,6 +6,8 @@ import pytz
from django.utils import timezone
from django.shortcuts import HttpResponse
from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed
from django.http.response import HttpResponseForbidden
from .utils import set_current_request
@ -43,6 +45,7 @@ class DemoMiddleware:
if self.DEMO_MODE_ENABLED:
print("Demo mode enabled, reject unsafe method and url")
raise MiddlewareNotUsed
def __call__(self, request):
if self.DEMO_MODE_ENABLED and request.method not in self.SAFE_METHOD \
@ -61,7 +64,31 @@ class RequestMiddleware:
set_current_request(request)
response = self.get_response(request)
is_request_api = request.path.startswith('/api')
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and not is_request_api:
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
not is_request_api:
age = request.session.get_expiry_age()
request.session.set_expiry(age)
return response
class RefererCheckMiddleware:
def __init__(self, get_response):
if not settings.REFERER_CHECK_ENABLED:
raise MiddlewareNotUsed
self.get_response = get_response
self.http_pattern = re.compile('https?://')
def check_referer(self, request):
referer = request.META.get('HTTP_REFERER', '')
referer = self.http_pattern.sub('', referer)
if not referer:
return True
remote_host = request.get_host()
return referer.startswith(remote_host)
def __call__(self, request):
match = self.check_referer(request)
if not match:
return HttpResponseForbidden('CSRF CHECK ERROR')
response = self.get_response(request)
return response

View File

@ -9,6 +9,9 @@ from ..const import CONFIG, DYNAMIC, PROJECT_DIR
OTP_ISSUER_NAME = CONFIG.OTP_ISSUER_NAME
OTP_VALID_WINDOW = CONFIG.OTP_VALID_WINDOW
# Auth DB
AUTH_DB = DYNAMIC.AUTH_DB
# Auth LDAP settings
AUTH_LDAP = DYNAMIC.AUTH_LDAP
AUTH_LDAP_SERVER_URI = DYNAMIC.AUTH_LDAP_SERVER_URI

View File

@ -76,12 +76,13 @@ MIDDLEWARE = [
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'jms_oidc_rp.middleware.OIDCRefreshIDTokenMiddleware',
'django_cas_ng.middleware.CASMiddleware',
'jumpserver.middleware.TimezoneMiddleware',
'jumpserver.middleware.DemoMiddleware',
'jumpserver.middleware.RequestMiddleware',
'jumpserver.middleware.RefererCheckMiddleware',
'orgs.middleware.OrgMiddleware',
'authentication.backends.oidc.middleware.OIDCRefreshIDTokenMiddleware',
'authentication.backends.cas.middleware.CASMiddleware',
]
@ -245,6 +246,6 @@ CACHES = {
}
}
FORCE_SCRIPT_NAME = CONFIG.FORCE_SCRIPT_NAME
SESSION_COOKIE_SECURE = CONFIG.SESSION_COOKIE_SECURE
CSRF_COOKIE_SECURE = CONFIG.CSRF_COOKIE_SECURE

View File

@ -12,6 +12,17 @@ DEFAULT_TERMINAL_COMMAND_STORAGE = {
},
}
TERMINAL_COMMAND_STORAGE = DYNAMIC.TERMINAL_COMMAND_STORAGE or {}
# Server 类型的录像存储
SERVER_REPLAY_STORAGE = CONFIG.SERVER_REPLAY_STORAGE
# SERVER_REPLAY_STORAGE = {
# 'TYPE': 's3',
# 'BUCKET': '',
# 'ACCESS_KEY': '',
# 'SECRET_KEY': '',
# 'ENDPOINT': ''
# }
DEFAULT_TERMINAL_REPLAY_STORAGE = {
"default": {
"TYPE": "server",
@ -64,6 +75,7 @@ BACKEND_ASSET_USER_AUTH_VAULT = False
DEFAULT_ORG_SHOW_ALL_USERS = CONFIG.DEFAULT_ORG_SHOW_ALL_USERS
PERM_SINGLE_ASSET_TO_UNGROUP_NODE = CONFIG.PERM_SINGLE_ASSET_TO_UNGROUP_NODE
PERM_EXPIRED_CHECK_PERIODIC = CONFIG.PERM_EXPIRED_CHECK_PERIODIC
WINDOWS_SSH_DEFAULT_SHELL = CONFIG.WINDOWS_SSH_DEFAULT_SHELL
FLOWER_URL = CONFIG.FLOWER_URL
@ -96,6 +108,8 @@ AUTH_EXPIRED_SECONDS = 60 * 5
# XPACK
XPACK_LICENSE_IS_VALID = DYNAMIC.XPACK_LICENSE_IS_VALID
XPACK_INTERFACE_LOGIN_TITLE = DYNAMIC.XPACK_INTERFACE_LOGIN_TITLE
LOGO_URLS = DYNAMIC.LOGO_URLS
CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED = CONFIG.CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED
@ -103,3 +117,4 @@ CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED = CONFIG.CHANGE_AUTH_PLAN_SECURE_MODE_ENABL
DATETIME_DISPLAY_FORMAT = '%Y-%m-%d %H:%M:%S'
TICKETS_ENABLED = CONFIG.TICKETS_ENABLED
REFERER_CHECK_ENABLED = CONFIG.REFERER_CHECK_ENABLED

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -3,12 +3,13 @@
from rest_framework import viewsets
from rest_framework.exceptions import ValidationError
from django.db import transaction
from django.db.models import Q
from django.utils.translation import ugettext as _
from django.conf import settings
from assets.models import Asset, Node
from orgs.mixins.api import RootOrgViewMixin
from common.permissions import IsValidUser
from perms.utils import AssetPermissionUtil
from ..models import CommandExecution
from ..serializers import CommandExecutionSerializer
from ..tasks import run_command_execution
@ -27,9 +28,34 @@ class CommandExecutionViewSet(RootOrgViewMixin, viewsets.ModelViewSet):
data = serializer.validated_data
assets = data["hosts"]
system_user = data["run_as"]
util = AssetPermissionUtil(self.request.user)
util.filter_permissions(system_users=system_user.id)
permed_assets = util.get_assets().filter(id__in=[a.id for a in assets])
user = self.request.user
q = Q(granted_by_permissions__system_users__id=system_user.id) & (
Q(granted_by_permissions__users=user) |
Q(granted_by_permissions__user_groups__users=user)
)
permed_assets = set()
permed_assets.update(
Asset.objects.filter(
id__in=[a.id for a in assets]
).filter(q).distinct()
)
node_keys = Node.objects.filter(q).distinct().values_list('key', flat=True)
nodes_assets_q = Q()
for _key in node_keys:
nodes_assets_q |= Q(nodes__key__startswith=f'{_key}:')
nodes_assets_q |= Q(nodes__key=_key)
permed_assets.update(
Asset.objects.filter(
id__in=[a.id for a in assets]
).filter(
nodes_assets_q
).distinct()
)
invalid_assets = set(assets) - set(permed_assets)
if invalid_assets:
msg = _("Not has host {} permission").format(

View File

@ -4,6 +4,7 @@ import os
from kombu import Exchange, Queue
from celery import Celery
from celery.schedules import crontab
# set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'jumpserver.settings')
@ -19,6 +20,7 @@ configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
configs["CELERY_QUEUES"] = [
Queue("celery", Exchange("celery"), routing_key="celery"),
Queue("ansible", Exchange("ansible"), routing_key="ansible"),
Queue("celery_node_tree", Exchange("celery_node_tree"), routing_key="celery_node_tree")
]
configs["CELERY_ROUTES"] = {
"ops.tasks.run_ansible_task": {'exchange': 'ansible', 'routing_key': 'ansible'},
@ -27,3 +29,16 @@ configs["CELERY_ROUTES"] = {
app.namespace = 'CELERY'
app.conf.update(configs)
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
app.conf.beat_schedule = {
'check-asset-permission-expired': {
'task': 'perms.tasks.check_asset_permission_expired',
'schedule': settings.PERM_EXPIRED_CHECK_PERIODIC,
'args': ()
},
'check-node-assets-amount': {
'task': 'assets.tasks.nodes_amount.check_node_assets_amount_celery_task',
'schedule': crontab(minute=0, hour=0),
'args': ()
},
}

View File

@ -2,9 +2,11 @@
#
from django.utils.translation import ugettext as _
from django.db.models import Q
from rest_framework import status, generics
from rest_framework.views import Response
from rest_framework_bulk import BulkModelViewSet
from rest_framework.mixins import CreateModelMixin
from common.permissions import IsSuperUserOrAppUser
from common.drf.api import JMSBulkRelationModelViewSet
@ -19,6 +21,8 @@ from perms.models import AssetPermission
from orgs.utils import current_org
from common.utils import get_logger
from .filters import OrgMemberRelationFilterSet
from .models import OrganizationMember
logger = get_logger(__file__)
@ -71,6 +75,57 @@ class OrgMemberRelationBulkViewSet(JMSBulkRelationModelViewSet):
serializer_class = OrgMemberSerializer
filterset_class = OrgMemberRelationFilterSet
@staticmethod
def clear_request_data(request):
data = request.data
ignore_already_exist = request.query_params.get('ignore_already_exist')
if not ignore_already_exist:
return data
query_params = Q()
for _data in data:
query_fields = {}
org = _data.get('org')
if org:
query_fields.update({'org': org})
user = _data.get('user')
if user:
query_fields.update({'user': user})
role = _data.get('role')
if role:
query_fields.update({'role': role})
query_params |= Q(**query_fields)
if not query_params:
return data
members = OrganizationMember.objects.filter(query_params)
members = [
{'org': str(member.org_id), 'user': str(member.user_id), 'role': member.role}
for member in members
]
if not members:
return data
for member in members:
if member in data:
data.remove(member)
return data
def create(self, request, *args, **kwargs):
bulk = isinstance(request.data, list)
if not bulk:
return CreateModelMixin.create(self, request, *args, **kwargs)
else:
data = self.clear_request_data(request)
serializer = self.get_serializer(data=data, many=True)
serializer.is_valid(raise_exception=True)
self.perform_bulk_create(serializer)
return Response(serializer.data, status=status.HTTP_201_CREATED)
def perform_bulk_destroy(self, queryset):
objs = list(queryset.all().prefetch_related('user', 'org'))
queryset.delete()

131
apps/orgs/lock.py Normal file
View File

@ -0,0 +1,131 @@
from uuid import uuid4
from functools import wraps
from django.core.cache import cache
from django.db.transaction import atomic
from rest_framework.request import Request
from rest_framework.exceptions import NotAuthenticated
from orgs.utils import current_org
from common.exceptions import SomeoneIsDoingThis, Timeout
from common.utils.timezone import dt_formater, now
# Redis 中锁值得模板,该模板提供了很强的可读性,方便调试与排错
VALUE_TEMPLATE = '{stage}:{username}:{user_id}:{now}:{rand_str}'
# 锁的状态
DOING = 'doing' # 处理中,此状态的锁可以被干掉
COMMITING = 'commiting' # 提交事务中,此状态很重要,要确保事务在锁消失之前返回了,不要轻易删除该锁
client = cache.client.get_client(write=True)
"""
将锁的状态从 `doing` 切换到 `commiting`
KEYS[1] key
ARGV[1]: doingvalue
ARGV[2]: commitingvalue
ARGV[3]: timeout
"""
change_lock_state_to_commiting_lua = '''
if (redis.call("get", KEYS[1]) == ARGV[1])
then
return redis.call("set", KEYS[1], ARGV[2], "EX", ARGV[3], "XX")
else
return 0
end
'''
change_lock_state_to_commiting_lua_obj = client.register_script(change_lock_state_to_commiting_lua)
"""
释放锁两种`value`都要检查`doing``commiting`
KEYS[1] key
ARGV[1]: 两个 `value` 中的其中一个
ARGV[2]: 两个 `value` 中的其中一个
"""
release_lua = '''
if (redis.call("get",KEYS[1]) == ARGV[1] or redis.call("get",KEYS[1]) == ARGV[2])
then
return redis.call("del",KEYS[1])
else
return 0
end
'''
release_lua_obj = client.register_script(release_lua)
def acquire(key, value, timeout):
return client.set(key, value, ex=timeout, nx=True)
def get(key):
return client.get(key)
def change_lock_state_to_commiting(key, doingvalue, commitingvalue, timeout=600):
# 将锁的状态从 `doing` 切换到 `commiting`
return bool(change_lock_state_to_commiting_lua_obj(keys=(key,), args=(doingvalue, commitingvalue, timeout)))
def release(key, value1, value2):
# 释放锁,两种`value` `doing`和`commiting` 都要检查
return release_lua_obj(keys=(key,), args=(value1, value2))
def _generate_value(request: Request, stage=DOING):
# 不支持匿名用户
user = request.user
if user.is_anonymous:
raise NotAuthenticated
return VALUE_TEMPLATE.format(
stage=stage, username=user.username, user_id=user.id,
now=dt_formater(now()), rand_str=uuid4()
)
default_wait_msg = SomeoneIsDoingThis.default_detail
def org_level_transaction_lock(key, timeout=300, wait_msg=default_wait_msg):
"""
被装饰的 `View` 必须取消自身的 `ATOMIC_REQUESTS`因为该装饰器要有事务的完全控制权
[官网](https://docs.djangoproject.com/en/3.1/topics/db/transactions/#tying-transactions-to-http-requests)
1. 获取锁只有当锁对应的 `key` 不存在时成功获取`value` 设置为 `doing`
2. 开启事务本次请求的事务必须确保在这里开启
3. 执行 `View`
4. `View` 体执行结束未异常此时事务还未提交
5. 检查锁是否过时过时事务回滚不过时重新设置`key`延长`key`有效期已确保足够时间提交事务同时把`key`的状态改为`commiting`
6. 提交事务
7. 释放锁释放的时候会检查`doing``commiting`的值因为删除或者更改锁必须提供与当前锁的`value`相同的值确保不误删
[锁参考文章](http://doc.redisfans.com/string/set.html#id2)
"""
def decorator(fun):
@wraps(fun)
def wrapper(request, *args, **kwargs):
# `key`可能是组织相关的,如果是把组织`id`加上
_key = key.format(org_id=current_org.id)
doing_value = _generate_value(request)
commiting_value = _generate_value(request, stage=COMMITING)
try:
lock = acquire(_key, doing_value, timeout)
if not lock:
raise SomeoneIsDoingThis(detail=wait_msg)
with atomic(savepoint=False):
ret = fun(request, *args, **kwargs)
# 提交事务前,检查一下锁是否还在
# 锁在的话,更新锁的状态为 `commiting`,延长锁时间,确保事务提交
# 锁不在的话回滚
ok = change_lock_state_to_commiting(_key, doing_value, commiting_value)
if not ok:
# 超时或者被中断了
raise Timeout
return ret
finally:
# 释放锁,锁的两个值都要尝试,不确定异常是从什么位置抛出的
release(_key, commiting_value, doing_value)
return wrapper
return decorator

View File

@ -1,5 +1,6 @@
import uuid
from functools import partial
from itertools import chain
from django.db import models
from django.db.models import signals
@ -229,6 +230,10 @@ def _none2list(*args):
return ([] if v is None else v for v in args)
def _users2pks(users, admins, auditors):
return [user.pk for user in chain(users, admins, auditors)]
class UserRoleMapper(dict):
def __init__(self, container=set):
super().__init__()
@ -266,7 +271,7 @@ class OrgMemeberManager(models.Manager):
users, admins, auditors = _none2list(users, admins, auditors)
send = partial(signals.m2m_changed.send, sender=self.model, instance=org, reverse=False,
model=User, pk_set=[*users, *admins, *auditors], using=self.db)
model=User, pk_set=_users2pks(users, admins, auditors), using=self.db)
send(action="pre_remove")
self.filter(org_id=org.id).filter(
@ -297,7 +302,7 @@ class OrgMemeberManager(models.Manager):
oms_add.append(self.model(org_id=org.id, user_id=user, role=role))
send = partial(signals.m2m_changed.send, sender=self.model, instance=org, reverse=False,
model=User, pk_set=[*users, *admins, *auditors], using=self.db)
model=User, pk_set=_users2pks(users, admins, auditors), using=self.db)
send(action='pre_add')
self.bulk_create(oms_add)

View File

@ -2,9 +2,12 @@
#
from rest_framework import generics
from django.db.models import F, Value
from django.db.models import Q
from django.db.models.functions import Concat
from django.shortcuts import get_object_or_404
from assets.models import Node, Asset
from orgs.mixins.api import OrgRelationMixin
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.utils import current_org
from common.permissions import IsOrgAdmin
@ -19,9 +22,9 @@ __all__ = [
]
class RelationMixin(OrgBulkModelViewSet):
class RelationMixin(OrgRelationMixin, OrgBulkModelViewSet):
def get_queryset(self):
queryset = self.model.objects.all()
queryset = super().get_queryset()
org_id = current_org.org_id()
if org_id is not None:
queryset = queryset.filter(assetpermission__org_id=org_id)
@ -31,7 +34,7 @@ class RelationMixin(OrgBulkModelViewSet):
class AssetPermissionUserRelationViewSet(RelationMixin):
serializer_class = serializers.AssetPermissionUserRelationSerializer
model = models.AssetPermission.users.through
m2m_field = models.AssetPermission.users.field
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', "user", "assetpermission",
@ -62,7 +65,7 @@ class AssetPermissionAllUserListApi(generics.ListAPIView):
class AssetPermissionUserGroupRelationViewSet(RelationMixin):
serializer_class = serializers.AssetPermissionUserGroupRelationSerializer
model = models.AssetPermission.user_groups.through
m2m_field = models.AssetPermission.user_groups.field
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', "usergroup", "assetpermission"
@ -78,7 +81,7 @@ class AssetPermissionUserGroupRelationViewSet(RelationMixin):
class AssetPermissionAssetRelationViewSet(RelationMixin):
serializer_class = serializers.AssetPermissionAssetRelationSerializer
model = models.AssetPermission.assets.through
m2m_field = models.AssetPermission.assets.field
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', 'asset', 'assetpermission',
@ -101,15 +104,20 @@ class AssetPermissionAllAssetListApi(generics.ListAPIView):
def get_queryset(self):
pk = self.kwargs.get("pk")
perm = get_object_or_404(models.AssetPermission, pk=pk)
assets = perm.get_all_assets().only(
*self.serializer_class.Meta.only_fields
)
asset_q = Q(granted_by_permissions=perm)
granted_node_keys = Node.objects.filter(granted_by_permissions=perm).distinct().values_list('key', flat=True)
for key in granted_node_keys:
asset_q |= Q(nodes__key__startswith=f'{key}:')
asset_q |= Q(nodes__key=key)
assets = Asset.objects.filter(asset_q).only(*self.serializer_class.Meta.only_fields).distinct()
return assets
class AssetPermissionNodeRelationViewSet(RelationMixin):
serializer_class = serializers.AssetPermissionNodeRelationSerializer
model = models.AssetPermission.nodes.through
m2m_field = models.AssetPermission.nodes.field
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', 'node', 'assetpermission',
@ -125,7 +133,7 @@ class AssetPermissionNodeRelationViewSet(RelationMixin):
class AssetPermissionSystemUserRelationViewSet(RelationMixin):
serializer_class = serializers.AssetPermissionSystemUserRelationSerializer
model = models.AssetPermission.system_users.through
m2m_field = models.AssetPermission.system_users.field
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', 'systemuser', 'assetpermission',

View File

@ -1,21 +1,24 @@
from rest_framework import generics
from django.db.models import Q
from django.utils.decorators import method_decorator
from assets.models import SystemUser
from common.permissions import IsValidUser
from orgs.utils import tmp_to_root_org
from .. import serializers
@method_decorator(tmp_to_root_org(), name='list')
class SystemUserPermission(generics.ListAPIView):
permission_classes = (IsValidUser,)
serializer_class = serializers.SystemUserSerializer
def get_queryset(self):
return self.get_user_system_users()
def get_user_system_users(self):
from perms.utils import AssetPermissionUtil
user = self.request.user
with tmp_to_root_org():
util = AssetPermissionUtil(user)
return util.get_system_users()
queryset = SystemUser.objects.filter(
Q(granted_by_permissions__users=user) |
Q(granted_by_permissions__user_groups__users=user)
).distinct()
return queryset

View File

@ -59,7 +59,8 @@ class UserGrantedDatabaseAppsAsTreeApi(UserGrantedDatabaseAppsApi):
tree_root = None
data = []
if not only_database_app:
tree_root = utils.construct_database_apps_tree_root()
amount = len(database_apps)
tree_root = utils.construct_database_apps_tree_root(amount)
data.append(tree_root)
for database_app in database_apps:
node = utils.parse_database_app_to_tree_node(tree_root, database_app)

View File

@ -1,44 +1,169 @@
# -*- coding: utf-8 -*-
#
from itertools import chain
from django.db.models import Q
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from common.permissions import IsOrgAdminOrAppUser
from common.utils import lazyproperty
from perms.models import AssetPermission
from assets.models import Asset, Node
from . import user_permission as uapi
from .mixin import UserGroupPermissionMixin
from perms import serializers
from perms.utils.asset_permission import get_asset_system_users_id_with_actions_by_group
from assets.api.mixin import SerializeToTreeNodeMixin
from users.models import UserGroup
__all__ = [
'UserGroupGrantedAssetsApi', 'UserGroupGrantedNodesApi',
'UserGroupGrantedNodeAssetsApi', 'UserGroupGrantedNodeChildrenApi',
'UserGroupGrantedNodeAssetsApi',
'UserGroupGrantedNodeChildrenAsTreeApi',
'UserGroupGrantedNodeChildrenWithAssetsAsTreeApi',
'UserGroupGrantedAssetSystemUsersApi',
# 'UserGroupGrantedNodeChildrenWithAssetsAsTreeApi',
]
class UserGroupGrantedAssetsApi(UserGroupPermissionMixin, uapi.UserGrantedAssetsApi):
pass
class UserGroupMixin:
@lazyproperty
def group(self):
group_id = self.kwargs.get('pk')
return UserGroup.objects.get(id=group_id)
class UserGroupGrantedNodeAssetsApi(UserGroupPermissionMixin, uapi.UserGrantedNodeAssetsApi):
pass
class UserGroupGrantedAssetsApi(ListAPIView):
"""
获取用户组直接授权的资产
"""
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filter_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
user_group_id = self.kwargs.get('pk', '')
return Asset.objects.filter(
Q(granted_by_permissions__user_groups__id=user_group_id)
).distinct().only(
*self.only_fields
)
class UserGroupGrantedNodesApi(UserGroupPermissionMixin, uapi.UserGrantedNodesApi):
pass
class UserGroupGrantedNodeAssetsApi(ListAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filter_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
user_group_id = self.kwargs.get('pk', '')
node_id = self.kwargs.get("node_id")
node = Node.objects.get(id=node_id)
granted = AssetPermission.objects.filter(
user_groups__id=user_group_id,
nodes__id=node_id
).exists()
if granted:
assets = Asset.objects.filter(
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
)
return assets
else:
granted_node_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=user_group_id,
key__startswith=f'{node.key}:'
).distinct().values_list('key', flat=True)
granted_node_q = Q()
for _key in granted_node_keys:
granted_node_q |= Q(nodes__key__startswith=f'{_key}:')
granted_node_q |= Q(nodes__key=_key)
granted_asset_q = (
Q(granted_by_permissions__user_groups__id=user_group_id) &
(
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
)
)
assets = Asset.objects.filter(
granted_node_q | granted_asset_q
).distinct()
return assets
class UserGroupGrantedNodeChildrenApi(UserGroupPermissionMixin, uapi.UserGrantedNodeChildrenApi):
pass
class UserGroupGrantedNodesApi(ListAPIView):
serializer_class = serializers.NodeGrantedSerializer
permission_classes = (IsOrgAdminOrAppUser,)
def get_queryset(self):
user_group_id = self.kwargs.get('pk', '')
nodes = Node.objects.filter(
Q(granted_by_permissions__user_groups__id=user_group_id) |
Q(assets__granted_by_permissions__user_groups__id=user_group_id)
)
return nodes
class UserGroupGrantedNodeChildrenAsTreeApi(UserGroupPermissionMixin, uapi.UserGrantedNodeChildrenAsTreeApi):
pass
class UserGroupGrantedNodeChildrenAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
def get_children_nodes(self, parent_key):
return Node.objects.filter(parent_key=parent_key)
def add_children_key(self, node_key, key, key_set):
if key.startswith(f'{node_key}:'):
try:
end = key.index(':', len(node_key) + 1)
key_set.add(key[:end])
except ValueError:
key_set.add(key)
def get_nodes(self):
group_id = self.kwargs.get('pk')
node_key = self.request.query_params.get('key', None)
granted_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=group_id
).values_list('key', flat=True)
asset_granted_keys = Node.objects.filter(
assets__granted_by_permissions__user_groups__id=group_id
).values_list('key', flat=True)
if node_key is None:
root_keys = set()
for _key in chain(granted_keys, asset_granted_keys):
root_keys.add(_key.split(':', 1)[0])
return Node.objects.filter(key__in=root_keys)
else:
children_keys = set()
for _key in granted_keys:
# 判断当前节点是否是授权节点
if node_key == _key:
return self.get_children_nodes(node_key)
# 判断当前节点有没有授权的父节点
if node_key.startswith(f'{_key}:'):
return self.get_children_nodes(node_key)
self.add_children_key(node_key, _key, children_keys)
for _key in asset_granted_keys:
self.add_children_key(node_key, _key, children_keys)
return Node.objects.filter(key__in=children_keys)
def list(self, request, *args, **kwargs):
nodes = self.get_nodes()
nodes = self.serialize_nodes(nodes)
return Response(data=nodes)
class UserGroupGrantedNodeChildrenWithAssetsAsTreeApi(UserGroupPermissionMixin, uapi.UserGrantedNodeChildrenWithAssetsAsTreeApi):
pass
class UserGroupGrantedAssetSystemUsersApi(UserGroupPermissionMixin, uapi.UserGrantedAssetSystemUsersApi):
pass
class UserGroupGrantedAssetSystemUsersApi(UserGroupMixin, uapi.UserGrantedAssetSystemUsersForAdminApi):
def get_asset_system_users_id_with_actions(self, asset):
return get_asset_system_users_id_with_actions_by_group(self.group, asset)

View File

@ -51,7 +51,8 @@ class UserGrantedK8sAppsAsTreeApi(UserGrantedK8sAppsApi):
tree_root = None
data = []
if not only_k8s_app:
tree_root = utils.construct_k8s_apps_tree_root()
amount = len(k8s_apps)
tree_root = utils.construct_k8s_apps_tree_root(amount)
data.append(tree_root)
for k8s_app in k8s_apps:
node = utils.parse_k8s_app_to_tree_node(tree_root, k8s_app)

View File

@ -3,38 +3,38 @@
import uuid
from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from rest_framework.views import APIView, Response
from rest_framework.generics import (
ListAPIView, get_object_or_404, RetrieveAPIView, DestroyAPIView
)
from common.permissions import IsOrgAdminOrAppUser, IsOrgAdmin
from common.utils import get_logger
from ...utils import (
AssetPermissionUtil
)
from orgs.utils import tmp_to_root_org
from perms.utils.asset_permission import get_asset_system_users_id_with_actions_by_user
from common.permissions import IsOrgAdminOrAppUser, IsOrgAdmin, IsValidUser
from common.utils import get_logger, lazyproperty
from ...hands import User, Asset, SystemUser
from ... import serializers
from ...models import Action
from .mixin import UserAssetPermissionMixin
logger = get_logger(__name__)
__all__ = [
'RefreshAssetPermissionCacheApi',
'UserGrantedAssetSystemUsersApi',
'UserGrantedAssetSystemUsersForAdminApi',
'ValidateUserAssetPermissionApi',
'GetUserAssetPermissionActionsApi',
'UserAssetPermissionsCacheApi',
'MyGrantedAssetSystemUsersApi',
]
class GetUserAssetPermissionActionsApi(UserAssetPermissionMixin,
RetrieveAPIView):
class GetUserAssetPermissionActionsApi(RetrieveAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.ActionsSerializer
def get_obj(self):
def get_user(self):
user_id = self.request.query_params.get('user_id', '')
user = get_object_or_404(User, id=user_id)
return user
@ -52,18 +52,18 @@ class GetUserAssetPermissionActionsApi(UserAssetPermissionMixin,
asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id)
system_users_actions = self.util.get_asset_system_users_id_with_actions(asset)
system_users_actions = get_asset_system_users_id_with_actions_by_user(self.get_user(), asset)
actions = system_users_actions.get(system_user.id)
return {"actions": actions}
class ValidateUserAssetPermissionApi(UserAssetPermissionMixin, APIView):
class ValidateUserAssetPermissionApi(APIView):
permission_classes = (IsOrgAdminOrAppUser,)
def get_cache_policy(self):
return 0
def get_obj(self):
def get_user(self):
user_id = self.request.query_params.get('user_id', '')
user = get_object_or_404(User, id=user_id)
return user
@ -82,7 +82,7 @@ class ValidateUserAssetPermissionApi(UserAssetPermissionMixin, APIView):
asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_id)
system_users_actions = self.util.get_asset_system_users_id_with_actions(asset)
system_users_actions = get_asset_system_users_id_with_actions_by_user(self.get_user(), asset)
actions = system_users_actions.get(system_user.id)
if actions is None:
return Response({'msg': False}, status=403)
@ -91,23 +91,31 @@ class ValidateUserAssetPermissionApi(UserAssetPermissionMixin, APIView):
return Response({'msg': False}, status=403)
# TODO 删除
class RefreshAssetPermissionCacheApi(RetrieveAPIView):
permission_classes = (IsOrgAdmin,)
def retrieve(self, request, *args, **kwargs):
AssetPermissionUtil.expire_all_user_tree_cache()
return Response({'msg': True}, status=200)
class UserGrantedAssetSystemUsersApi(UserAssetPermissionMixin, ListAPIView):
class UserGrantedAssetSystemUsersForAdminApi(ListAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.AssetSystemUserSerializer
only_fields = serializers.AssetSystemUserSerializer.Meta.only_fields
@lazyproperty
def user(self):
user_id = self.kwargs.get('pk')
return User.objects.get(id=user_id)
def get_asset_system_users_id_with_actions(self, asset):
return get_asset_system_users_id_with_actions_by_user(self.user, asset)
def get_queryset(self):
asset_id = self.kwargs.get('asset_id')
asset = get_object_or_404(Asset, id=asset_id)
system_users_with_actions = self.util.get_asset_system_users_id_with_actions(asset)
system_users_with_actions = self.get_asset_system_users_id_with_actions(asset)
system_users_id = system_users_with_actions.keys()
system_users = SystemUser.objects.filter(id__in=system_users_id)\
.only(*self.serializer_class.Meta.only_fields) \
@ -119,9 +127,18 @@ class UserGrantedAssetSystemUsersApi(UserAssetPermissionMixin, ListAPIView):
return system_users
class UserAssetPermissionsCacheApi(UserAssetPermissionMixin, DestroyAPIView):
@method_decorator(tmp_to_root_org(), name='list')
class MyGrantedAssetSystemUsersApi(UserGrantedAssetSystemUsersForAdminApi):
permission_classes = (IsValidUser,)
@lazyproperty
def user(self):
return self.request.user
# TODO 删除
class UserAssetPermissionsCacheApi(DestroyAPIView):
permission_classes = (IsOrgAdmin,)
def destroy(self, request, *args, **kwargs):
self.util.expire_user_tree_cache()
return Response(status=204)

View File

@ -1,83 +1,52 @@
# -*- coding: utf-8 -*-
#
from rest_framework.request import Request
from common.permissions import IsOrgAdminOrAppUser, IsValidUser
from common.utils import lazyproperty
from common.tree import TreeNodeSerializer
from django.db.models import QuerySet
from ..mixin import UserPermissionMixin
from ...utils import AssetPermissionUtil, ParserNode
from ...hands import Node, Asset
from users.models import User
from perms.models import UserGrantedMappingNode
class UserAssetPermissionMixin(UserPermissionMixin):
util = None
def get_cache_policy(self):
return self.request.query_params.get('cache_policy', '0')
@lazyproperty
def util(self):
cache_policy = self.get_cache_policy()
system_user_id = self.request.query_params.get("system_user")
util = AssetPermissionUtil(self.obj, cache_policy=cache_policy)
if system_user_id:
util.filter_permissions(system_users=system_user_id)
return util
@lazyproperty
def tree(self):
return self.util.get_user_tree()
class UserNodeTreeMixin:
serializer_class = TreeNodeSerializer
nodes_only_fields = ParserNode.nodes_only_fields
def parse_nodes_to_queryset(self, nodes):
if isinstance(nodes, QuerySet):
nodes = nodes.only(*self.nodes_only_fields)
_queryset = []
for node in nodes:
assets_amount = self.tree.valid_assets_amount(node.key)
if assets_amount == 0 and not node.key.startswith('-'):
continue
node.assets_amount = assets_amount
data = ParserNode.parse_node_to_tree_node(node)
_queryset.append(data)
return _queryset
def get_serializer_queryset(self, queryset):
queryset = self.parse_nodes_to_queryset(queryset)
return queryset
def get_serializer(self, queryset=None, many=True, **kwargs):
if queryset is None:
queryset = Node.objects.none()
queryset = self.get_serializer_queryset(queryset)
queryset.sort()
return super().get_serializer(queryset, many=many, **kwargs)
class UserAssetTreeMixin:
serializer_class = TreeNodeSerializer
nodes_only_fields = ParserNode.assets_only_fields
class UserNodeGrantStatusDispatchMixin:
@staticmethod
def parse_assets_to_queryset(assets, node):
_queryset = []
for asset in assets:
data = ParserNode.parse_asset_to_tree_node(node, asset)
_queryset.append(data)
return _queryset
def get_mapping_node_by_key(key, user):
return UserGrantedMappingNode.objects.get(key=key, user=user)
def get_serializer_queryset(self, queryset):
queryset = queryset.only(*self.nodes_only_fields)
_queryset = self.parse_assets_to_queryset(queryset, None)
return _queryset
def dispatch_get_data(self, key, user):
status = UserGrantedMappingNode.get_node_granted_status(key, user)
if status == UserGrantedMappingNode.GRANTED_DIRECT:
return self.get_data_on_node_direct_granted(key)
elif status == UserGrantedMappingNode.GRANTED_INDIRECT:
return self.get_data_on_node_indirect_granted(key)
else:
return self.get_data_on_node_not_granted(key)
def get_serializer(self, queryset=None, many=True, **kwargs):
if queryset is None:
queryset = Asset.objects.none()
queryset = self.get_serializer_queryset(queryset)
queryset.sort()
return super().get_serializer(queryset, many=many, **kwargs)
def get_data_on_node_direct_granted(self, key):
raise NotImplementedError
def get_data_on_node_indirect_granted(self, key):
raise NotImplementedError
def get_data_on_node_not_granted(self, key):
raise NotImplementedError
class ForAdminMixin:
permission_classes = (IsOrgAdminOrAppUser,)
kwargs: dict
@lazyproperty
def user(self):
user_id = self.kwargs.get('pk')
return User.objects.get(id=user_id)
class ForUserMixin:
permission_classes = (IsValidUser,)
request: Request
@lazyproperty
def user(self):
return self.request.user

View File

@ -1,66 +1,145 @@
# -*- coding: utf-8 -*-
#
from django.shortcuts import get_object_or_404
from django.conf import settings
from django.utils.decorators import method_decorator
from perms.api.user_permission.mixin import UserNodeGrantStatusDispatchMixin
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from django.conf import settings
from common.permissions import IsOrgAdminOrAppUser
from common.utils import get_logger, timeit
from ...hands import Node
from assets.api.mixin import SerializeToTreeNodeMixin
from common.utils import get_logger
from perms.pagination import GrantedAssetLimitOffsetPagination
from assets.models import Asset, Node, FavoriteAsset
from orgs.utils import tmp_to_root_org
from ... import serializers
from .mixin import UserAssetPermissionMixin, UserAssetTreeMixin
from ...utils.user_asset_permission import (
get_node_all_granted_assets, get_user_direct_granted_assets,
get_user_granted_all_assets
)
from .mixin import ForAdminMixin, ForUserMixin
logger = get_logger(__name__)
__all__ = [
'UserGrantedAssetsApi', 'UserGrantedAssetsAsTreeApi',
'UserGrantedNodeAssetsApi',
]
class UserGrantedAssetsApi(UserAssetPermissionMixin, ListAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
@method_decorator(tmp_to_root_org(), name='list')
class UserDirectGrantedAssetsApi(ListAPIView):
"""
用户直接授权的资产的列表也就是授权规则上直接授权的资产并非是来自节点的
"""
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filter_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def filter_by_nodes(self, queryset):
node_id = self.request.query_params.get("node")
if not node_id:
return queryset
node = get_object_or_404(Node, pk=node_id)
query_all = self.request.query_params.get("all", "0") in ["1", "true"]
if query_all:
pattern = '^{0}$|^{0}:'.format(node.key)
queryset = queryset.filter(nodes__key__regex=pattern).distinct()
else:
queryset = queryset.filter(nodes=node)
return queryset
def get_queryset(self):
user = self.user
assets = get_user_direct_granted_assets(user)\
.prefetch_related('platform')\
.only(*self.only_fields)
return assets
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_by_nodes(queryset)
return queryset
@method_decorator(tmp_to_root_org(), name='list')
class UserFavoriteGrantedAssetsApi(ListAPIView):
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filter_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
queryset = self.util.get_assets().only(*self.only_fields).order_by(
settings.TERMINAL_ASSET_LIST_SORT_BY
)
return queryset
user = self.user
assets = FavoriteAsset.get_user_favorite_assets(user)\
.prefetch_related('platform')\
.only(*self.only_fields)
return assets
class UserGrantedAssetsAsTreeApi(UserAssetTreeMixin, UserGrantedAssetsApi):
@method_decorator(tmp_to_root_org(), name='list')
class AssetsAsTreeMixin(SerializeToTreeNodeMixin):
"""
资产 序列化成树的结构返回
"""
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
data = self.serialize_assets(queryset, None)
return Response(data=data)
class UserDirectGrantedAssetsForAdminApi(ForAdminMixin, UserDirectGrantedAssetsApi):
pass
class UserGrantedNodeAssetsApi(UserGrantedAssetsApi):
class MyDirectGrantedAssetsApi(ForUserMixin, UserDirectGrantedAssetsApi):
pass
class UserFavoriteGrantedAssetsForAdminApi(ForAdminMixin, UserFavoriteGrantedAssetsApi):
pass
class MyFavoriteGrantedAssetsApi(ForUserMixin, UserFavoriteGrantedAssetsApi):
pass
@method_decorator(tmp_to_root_org(), name='list')
class UserDirectGrantedAssetsAsTreeForAdminApi(ForAdminMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi):
pass
@method_decorator(tmp_to_root_org(), name='list')
class MyUngroupAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserDirectGrantedAssetsApi):
def get_queryset(self):
queryset = super().get_queryset()
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
queryset = queryset.none()
return queryset
@method_decorator(tmp_to_root_org(), name='list')
class UserAllGrantedAssetsApi(ListAPIView):
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
def get_queryset(self):
queryset = get_user_granted_all_assets(self.user)
queryset = queryset.prefetch_related('platform')
return queryset.only(*self.only_fields)
class MyAllAssetsAsTreeApi(ForUserMixin, AssetsAsTreeMixin, UserAllGrantedAssetsApi):
search_fields = ['hostname', 'ip']
@method_decorator(tmp_to_root_org(), name='list')
class UserGrantedNodeAssetsApi(UserNodeGrantStatusDispatchMixin, ListAPIView):
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
filter_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
pagination_class = GrantedAssetLimitOffsetPagination
pagination_node: Node
def get_queryset(self):
node_id = self.kwargs.get("node_id")
node = get_object_or_404(Node, pk=node_id)
deep = self.request.query_params.get("all", "0") == "1"
queryset = self.util.get_nodes_assets(node, deep=deep)\
.only(*self.only_fields)
return queryset
node = Node.objects.get(id=node_id)
self.pagination_node = node
return self.dispatch_get_data(node.key, self.user)
def get_data_on_node_direct_granted(self, key):
# 如果这个节点是直接授权的(或者说祖先节点直接授权的), 获取下面的所有资产
return Node.get_node_all_assets_by_key_v2(key)
def get_data_on_node_indirect_granted(self, key):
self.pagination_node = self.get_mapping_node_by_key(key, self.user)
return get_node_all_granted_assets(self.user, key)
def get_data_on_node_not_granted(self, key):
return Asset.objects.none()
class UserGrantedNodeAssetsForAdminApi(ForAdminMixin, UserGrantedNodeAssetsApi):
pass
class MyGrantedNodeAssetsApi(ForUserMixin, UserGrantedNodeAssetsApi):
pass

View File

@ -1,82 +1,151 @@
# -*- coding: utf-8 -*-
#
from django.shortcuts import get_object_or_404
import abc
from rest_framework.generics import (
ListAPIView, get_object_or_404
ListAPIView
)
from rest_framework.response import Response
from rest_framework.request import Request
from common.permissions import IsOrgAdminOrAppUser
from orgs.utils import tmp_to_root_org
from assets.api.mixin import SerializeToTreeNodeMixin
from common.utils import get_logger
from ...hands import Node, NodeSerializer
from .mixin import ForAdminMixin, ForUserMixin, UserNodeGrantStatusDispatchMixin
from ...hands import Node, User
from ... import serializers
from .mixin import UserNodeTreeMixin, UserAssetPermissionMixin
from ...utils.user_asset_permission import (
get_indirect_granted_node_children,
get_user_granted_nodes_list_via_mapping_node,
get_top_level_granted_nodes,
rebuild_user_tree_if_need,
)
logger = get_logger(__name__)
__all__ = [
'UserGrantedNodesApi',
'UserGrantedNodesAsTreeApi',
'UserGrantedNodeChildrenApi',
'UserGrantedNodeChildrenAsTreeApi',
'UserGrantedNodesForAdminApi',
'MyGrantedNodesApi',
'MyGrantedNodesAsTreeApi',
'UserGrantedNodeChildrenForAdminApi',
'MyGrantedNodeChildrenApi',
'UserGrantedNodeChildrenAsTreeForAdminApi',
'MyGrantedNodeChildrenAsTreeApi',
'BaseGrantedNodeAsTreeApi',
'UserGrantedNodesMixin',
]
class UserGrantedNodesApi(UserAssetPermissionMixin, ListAPIView):
"""
查询用户授权的所有节点的API
"""
permission_classes = (IsOrgAdminOrAppUser,)
class _GrantedNodeStructApi(ListAPIView, metaclass=abc.ABCMeta):
@property
def user(self):
raise NotImplementedError
def get_nodes(self):
# 不使用 `get_queryset` 单独定义 `get_nodes` 的原因是
# `get_nodes` 返回的不一定是 `queryset`
raise NotImplementedError
class NodeChildrenMixin:
def get_children(self):
raise NotImplementedError
def get_nodes(self):
nodes = self.get_children()
return nodes
class BaseGrantedNodeApi(_GrantedNodeStructApi, metaclass=abc.ABCMeta):
serializer_class = serializers.NodeGrantedSerializer
nodes_only_fields = NodeSerializer.Meta.only_fields
def get_serializer_context(self):
context = super().get_serializer_context()
if self.serializer_class == serializers.NodeGrantedSerializer:
context["tree"] = self.tree
return context
def get_queryset(self):
node_keys = self.util.get_nodes()
queryset = Node.objects.filter(key__in=node_keys)\
.only(*self.nodes_only_fields)
return queryset
@tmp_to_root_org()
def list(self, request, *args, **kwargs):
rebuild_user_tree_if_need(request, self.user)
nodes = self.get_nodes()
serializer = self.get_serializer(nodes, many=True)
return Response(serializer.data)
class UserGrantedNodesAsTreeApi(UserNodeTreeMixin, UserGrantedNodesApi):
class BaseNodeChildrenApi(NodeChildrenMixin, BaseGrantedNodeApi, metaclass=abc.ABCMeta):
pass
class UserGrantedNodeChildrenApi(UserGrantedNodesApi):
node = None
root_keys = None # 如果是第一次访问,则需要把二级节点添加进去,这个 roots_keys
class BaseGrantedNodeAsTreeApi(SerializeToTreeNodeMixin, _GrantedNodeStructApi, metaclass=abc.ABCMeta):
@tmp_to_root_org()
def list(self, request: Request, *args, **kwargs):
rebuild_user_tree_if_need(request, self.user)
nodes = self.get_nodes()
nodes = self.serialize_nodes(nodes, with_asset_amount=True)
return Response(data=nodes)
def get(self, request, *args, **kwargs):
key = self.request.query_params.get("key")
pk = self.request.query_params.get("id")
node = None
if pk is not None:
node = get_object_or_404(Node, id=pk)
elif key is not None:
node = get_object_or_404(Node, key=key)
self.node = node
return super().get(request, *args, **kwargs)
class BaseNodeChildrenAsTreeApi(NodeChildrenMixin, BaseGrantedNodeAsTreeApi, metaclass=abc.ABCMeta):
pass
def get_queryset(self):
if self.node:
children = self.tree.children(self.node.key)
class UserGrantedNodeChildrenMixin(UserNodeGrantStatusDispatchMixin):
user: User
request: Request
def get_children(self):
user = self.user
key = self.request.query_params.get('key')
if not key:
nodes = list(get_top_level_granted_nodes(user))
else:
children = self.tree.children(self.tree.root)
# 默认打开组织节点下的节点
self.root_keys = [child.identifier for child in children]
for key in self.root_keys:
children.extend(self.tree.children(key))
node_keys = [n.identifier for n in children]
queryset = Node.objects.filter(key__in=node_keys)
return queryset
nodes = self.dispatch_get_data(key, user)
return nodes
def get_data_on_node_direct_granted(self, key):
return Node.objects.filter(parent_key=key)
def get_data_on_node_indirect_granted(self, key):
nodes = get_indirect_granted_node_children(self.user, key)
return nodes
def get_data_on_node_not_granted(self, key):
return Node.objects.none()
class UserGrantedNodeChildrenAsTreeApi(UserNodeTreeMixin, UserGrantedNodeChildrenApi):
class UserGrantedNodesMixin:
"""
查询用户授权的所有节点 直接授权节点 + 授权资产关联的节点
"""
user: User
def get_nodes(self):
return get_user_granted_nodes_list_via_mapping_node(self.user)
# ------------------------------------------
# 最终的 api
class UserGrantedNodeChildrenForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi):
pass
class MyGrantedNodeChildrenApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenApi):
pass
class UserGrantedNodeChildrenAsTreeForAdminApi(ForAdminMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi):
pass
class MyGrantedNodeChildrenAsTreeApi(ForUserMixin, UserGrantedNodeChildrenMixin, BaseNodeChildrenAsTreeApi):
pass
class UserGrantedNodesForAdminApi(ForAdminMixin, UserGrantedNodesMixin, BaseGrantedNodeApi):
pass
class MyGrantedNodesApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeApi):
pass
class MyGrantedNodesAsTreeApi(ForUserMixin, UserGrantedNodesMixin, BaseGrantedNodeAsTreeApi):
pass
# ------------------------------------------

View File

@ -1,56 +1,118 @@
# -*- coding: utf-8 -*-
#
from common.utils import get_logger
from ...utils import ParserNode
from .mixin import UserAssetTreeMixin
from ...hands import Node
from .user_permission_nodes import UserGrantedNodesAsTreeApi
from .user_permission_nodes import UserGrantedNodeChildrenAsTreeApi
from rest_framework.generics import ListAPIView
from rest_framework.request import Request
from rest_framework.response import Response
from common.permissions import IsValidUser
from common.utils import get_logger, get_object_or_none
from .mixin import UserNodeGrantStatusDispatchMixin, ForUserMixin, ForAdminMixin
from ...utils.user_asset_permission import (
get_user_resources_q_granted_by_permissions,
get_indirect_granted_node_children, UNGROUPED_NODE_KEY, FAVORITE_NODE_KEY,
get_user_direct_granted_assets, get_top_level_granted_nodes,
get_user_granted_nodes_list_via_mapping_node,
get_user_granted_all_assets, rebuild_user_tree_if_need,
get_user_all_assetpermission_ids,
)
from assets.models import Asset, FavoriteAsset
from assets.api import SerializeToTreeNodeMixin
from orgs.utils import tmp_to_root_org
from ...hands import Node
logger = get_logger(__name__)
__all__ = [
'UserGrantedNodesAsTreeApi',
'UserGrantedNodesWithAssetsAsTreeApi',
'UserGrantedNodeChildrenAsTreeApi',
'UserGrantedNodeChildrenWithAssetsAsTreeApi',
]
class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
permission_classes = (IsValidUser,)
@tmp_to_root_org()
def list(self, request: Request, *args, **kwargs):
"""
此算法依赖 UserGrantedMappingNode
获取所有授权的节点和资产
Node = UserGrantedMappingNode + 授权节点的子节点
Asset = 授权节点的资产 + 直接授权的资产
"""
user = request.user
rebuild_user_tree_if_need(request, user)
all_nodes = get_user_granted_nodes_list_via_mapping_node(user)
all_assets = get_user_granted_all_assets(user)
data = [
*self.serialize_nodes(all_nodes, with_asset_amount=True),
*self.serialize_assets(all_assets)
]
return Response(data=data)
class UserGrantedNodesWithAssetsAsTreeApi(UserGrantedNodesAsTreeApi):
assets_only_fields = ParserNode.assets_only_fields
class UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi(ForAdminMixin, UserNodeGrantStatusDispatchMixin,
SerializeToTreeNodeMixin, ListAPIView):
"""
带资产的授权树
"""
def get_serializer_queryset(self, queryset):
_queryset = super().get_serializer_queryset(queryset)
_all_assets = self.util.get_assets().only(*self.assets_only_fields)
_all_assets_map = {a.id: a for a in _all_assets}
for node in queryset:
assets_ids = self.tree.assets(node.key)
assets = [_all_assets_map[_id] for _id in assets_ids if _id in _all_assets_map]
_queryset.extend(
UserAssetTreeMixin.parse_assets_to_queryset(assets, node)
)
return _queryset
def get_data_on_node_direct_granted(self, key):
nodes = Node.objects.filter(parent_key=key)
assets = Asset.org_objects.filter(nodes__key=key).distinct()
assets = assets.prefetch_related('platform')
return nodes, assets
def get_data_on_node_indirect_granted(self, key):
user = self.user
asset_perm_ids = get_user_all_assetpermission_ids(user)
nodes = get_indirect_granted_node_children(user, key)
assets = Asset.org_objects.filter(
nodes__key=key,
).filter(
granted_by_permissions__id__in=asset_perm_ids
).distinct()
assets = assets.prefetch_related('platform')
return nodes, assets
def get_data_on_node_not_granted(self, key):
return Node.objects.none(), Asset.objects.none()
def get_data(self, key, user):
assets, nodes = [], []
if not key:
root_nodes = get_top_level_granted_nodes(user)
nodes.extend(root_nodes)
elif key == UNGROUPED_NODE_KEY:
assets = get_user_direct_granted_assets(user)
assets = assets.prefetch_related('platform')
elif key == FAVORITE_NODE_KEY:
assets = FavoriteAsset.get_user_favorite_assets(user)
else:
nodes, assets = self.dispatch_get_data(key, user)
return nodes, assets
def id2key_if_have(self):
id = self.request.query_params.get('id')
if id is not None:
node = get_object_or_none(Node, id=id)
if node:
return node.key
@tmp_to_root_org()
def list(self, request: Request, *args, **kwargs):
key = self.request.query_params.get('key')
if key is None:
key = self.id2key_if_have()
user = self.user
rebuild_user_tree_if_need(request, user)
nodes, assets = self.get_data(key, user)
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, key)
return Response(data=[*tree_nodes, *tree_assets])
class UserGrantedNodeChildrenWithAssetsAsTreeApi(UserGrantedNodeChildrenAsTreeApi):
nodes_only_fields = ParserNode.nodes_only_fields
assets_only_fields = ParserNode.assets_only_fields
def get_serializer_queryset(self, queryset):
_queryset = super().get_serializer_queryset(queryset)
nodes = []
if self.node:
nodes.append(self.node)
elif self.root_keys:
nodes = Node.objects.filter(key__in=self.root_keys)
for node in nodes:
assets = self.util.get_nodes_assets(node).only(
*self.assets_only_fields
)
_queryset.extend(
UserAssetTreeMixin.parse_assets_to_queryset(assets, node)
)
return _queryset
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi):
pass

View File

@ -59,7 +59,8 @@ class UserGrantedRemoteAppsAsTreeApi(UserGrantedRemoteAppsApi):
tree_root = None
data = []
if not only_remote_app:
tree_root = construct_remote_apps_tree_root()
amount = len(remote_apps)
tree_root = construct_remote_apps_tree_root(amount)
data.append(tree_root)
for remote_app in remote_apps:
node = parse_remote_app_to_tree_node(tree_root, remote_app)

View File

View File

@ -0,0 +1,47 @@
from django.utils.crypto import get_random_string
from perms.utils import rebuild_user_mapping_nodes_if_need_with_lock
from common.thread_pools import SingletonThreadPoolExecutor
from common.utils import get_logger
from perms.models import RebuildUserTreeTask
logger = get_logger(__name__)
class Executor(SingletonThreadPoolExecutor):
pass
executor = Executor()
def run_mapping_node_tasks():
failed_user_ids = []
ident = get_random_string()
logger.debug(f'[{ident}]mapping_node_tasks running')
while True:
task = RebuildUserTreeTask.objects.exclude(
user_id__in=failed_user_ids
).first()
if task is None:
break
user = task.user
try:
rebuild_user_mapping_nodes_if_need_with_lock(user)
except:
logger.exception(f'[{ident}]mapping_node_tasks_exception')
failed_user_ids.append(user.id)
logger.debug(f'[{ident}]mapping_node_tasks finished')
def submit_update_mapping_node_task():
executor.submit(run_mapping_node_tasks)
def submit_update_mapping_node_task_for_user(user):
executor.submit(rebuild_user_mapping_nodes_if_need_with_lock, user)

14
apps/perms/exceptions.py Normal file
View File

@ -0,0 +1,14 @@
from rest_framework import status
from django.utils.translation import gettext_lazy as _
from common.exceptions import JMSException
class AdminIsModifyingPerm(JMSException):
status_code = status.HTTP_409_CONFLICT
default_detail = _('The administrator is modifying permissions. Please wait')
class CanNotRemoveAssetPermNow(JMSException):
status_code = status.HTTP_409_CONFLICT
default_detail = _('The authorization cannot be revoked for the time being')

View File

@ -0,0 +1,53 @@
# Generated by Django 2.2.13 on 2020-09-07 08:40
import assets.models.node
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('assets', '0057_fill_node_value_assets_amount_and_parent_key'),
('perms', '0012_k8sapppermission'),
]
operations = [
migrations.CreateModel(
name='UserGrantedMappingNode',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('key', models.CharField(db_index=True, max_length=64, verbose_name='Key')),
('granted', models.BooleanField(db_index=True, default=False)),
('asset_granted', models.BooleanField(db_index=True, default=False)),
('parent_key', models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key')),
('assets_amount', models.IntegerField(default=0)),
('node', models.ForeignKey(db_constraint=False, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='mapping_nodes', to='assets.Node')),
('user', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
'abstract': False,
},
bases=(assets.models.node.FamilyMixin, models.Model),
),
migrations.CreateModel(
name='RebuildUserTreeTask',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User')),
],
options={
'abstract': False,
},
),
]

View File

@ -0,0 +1,27 @@
# Generated by Django 2.2.13 on 2020-08-21 08:20
from django.db import migrations
from perms.tasks import dispatch_mapping_node_tasks
def start_build_users_perm_tree_task(apps, schema_editor):
User = apps.get_model('users', 'User')
RebuildUserTreeTask = apps.get_model('perms', 'RebuildUserTreeTask')
user_ids = User.objects.all().values_list('id', flat=True).distinct()
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=i) for i in user_ids]
)
dispatch_mapping_node_tasks.delay()
class Migration(migrations.Migration):
dependencies = [
('perms', '0013_rebuildusertreetask_usergrantedmappingnode'),
]
operations = [
migrations.RunPython(start_build_users_perm_tree_task)
]

View File

@ -0,0 +1,54 @@
# Generated by Django 2.2.13 on 2020-09-29 09:28
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('perms', '0014_build_users_perm_tree'),
]
operations = [
migrations.AlterField(
model_name='assetpermission',
name='user_groups',
field=models.ManyToManyField(blank=True, related_name='assetpermissions', to='users.UserGroup', verbose_name='User group'),
),
migrations.AlterField(
model_name='assetpermission',
name='users',
field=models.ManyToManyField(blank=True, related_name='assetpermissions', to=settings.AUTH_USER_MODEL, verbose_name='User'),
),
migrations.AlterField(
model_name='databaseapppermission',
name='user_groups',
field=models.ManyToManyField(blank=True, related_name='databaseapppermissions', to='users.UserGroup', verbose_name='User group'),
),
migrations.AlterField(
model_name='databaseapppermission',
name='users',
field=models.ManyToManyField(blank=True, related_name='databaseapppermissions', to=settings.AUTH_USER_MODEL, verbose_name='User'),
),
migrations.AlterField(
model_name='k8sapppermission',
name='user_groups',
field=models.ManyToManyField(blank=True, related_name='k8sapppermissions', to='users.UserGroup', verbose_name='User group'),
),
migrations.AlterField(
model_name='k8sapppermission',
name='users',
field=models.ManyToManyField(blank=True, related_name='k8sapppermissions', to=settings.AUTH_USER_MODEL, verbose_name='User'),
),
migrations.AlterField(
model_name='remoteapppermission',
name='user_groups',
field=models.ManyToManyField(blank=True, related_name='remoteapppermissions', to='users.UserGroup', verbose_name='User group'),
),
migrations.AlterField(
model_name='remoteapppermission',
name='users',
field=models.ManyToManyField(blank=True, related_name='remoteapppermissions', to=settings.AUTH_USER_MODEL, verbose_name='User'),
),
]

View File

@ -2,20 +2,22 @@ import uuid
import logging
from functools import reduce
from django.db import models
from django.utils.translation import ugettext_lazy as _
from common.db import models
from common.utils import lazyproperty
from orgs.models import Organization
from orgs.utils import get_current_org
from assets.models import Asset, SystemUser, Node
from assets.models import Asset, SystemUser, Node, FamilyMixin
from .base import BasePermission
__all__ = [
'AssetPermission', 'Action',
'AssetPermission', 'Action', 'UserGrantedMappingNode', 'RebuildUserTreeTask',
]
# 使用场景
logger = logging.getLogger(__name__)
@ -97,6 +99,14 @@ class AssetPermission(BasePermission):
verbose_name = _("Asset permission")
ordering = ('name',)
@lazyproperty
def users_amount(self):
return self.users.count()
@lazyproperty
def user_groups_amount(self):
return self.user_groups.count()
@lazyproperty
def assets_amount(self):
return self.assets.count()
@ -174,3 +184,33 @@ class AssetPermission(BasePermission):
print('Error continue')
continue
class UserGrantedMappingNode(FamilyMixin, models.JMSBaseModel):
node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE,
db_constraint=False, null=True, related_name='mapping_nodes')
key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) # '1:1:1:1'
user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE)
granted = models.BooleanField(default=False, db_index=True)
asset_granted = models.BooleanField(default=False, db_index=True)
parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), db_index=True) # '1:1:1:1'
assets_amount = models.IntegerField(default=0)
GRANTED_DIRECT = 1
GRANTED_INDIRECT = 2
GRANTED_NONE = 0
@classmethod
def get_node_granted_status(cls, key, user):
ancestor_keys = Node.get_node_ancestor_keys(key, with_self=True)
has_granted = UserGrantedMappingNode.objects.filter(
key__in=ancestor_keys, user=user
).values_list('granted', flat=True)
if not has_granted:
return cls.GRANTED_NONE
if any(list(has_granted)):
return cls.GRANTED_DIRECT
return cls.GRANTED_INDIRECT
class RebuildUserTreeTask(models.JMSBaseModel):
user = models.ForeignKey('users.User', on_delete=models.CASCADE, verbose_name=_('User'))

View File

@ -13,7 +13,7 @@ from orgs.mixins.models import OrgManager
__all__ = [
'BasePermission',
'BasePermission', 'BasePermissionQuerySet'
]
@ -46,8 +46,8 @@ class BasePermissionManager(OrgManager):
class BasePermission(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
users = models.ManyToManyField('users.User', blank=True, verbose_name=_("User"))
user_groups = models.ManyToManyField('users.UserGroup', blank=True, verbose_name=_("User group"))
users = models.ManyToManyField('users.User', blank=True, verbose_name=_("User"), related_name='%(class)ss')
user_groups = models.ManyToManyField('users.UserGroup', blank=True, verbose_name=_("User group"), related_name='%(class)ss')
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
date_start = models.DateTimeField(default=timezone.now, db_index=True, verbose_name=_("Date start"))
date_expired = models.DateTimeField(default=date_expired_default, db_index=True, verbose_name=_('Date expired'))

30
apps/perms/pagination.py Normal file
View File

@ -0,0 +1,30 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from common.utils import get_logger
logger = get_logger(__name__)
class GrantedAssetLimitOffsetPagination(LimitOffsetPagination):
def get_count(self, queryset):
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw'
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
return super().get_count(queryset)
node = getattr(self._view, 'pagination_node', None)
if node:
logger.debug(f'{self._request.get_full_path()} hit node.assets_amount[{node.assets_amount}]')
return node.assets_amount
else:
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)

View File

@ -56,9 +56,5 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.annotate(
users_amount=Count('users', distinct=True), user_groups_amount=Count('user_groups', distinct=True),
assets_amount=Count('assets', distinct=True), nodes_amount=Count('nodes', distinct=True),
system_users_amount=Count('system_users', distinct=True)
)
queryset = queryset.prefetch_related('users', 'user_groups', 'assets', 'nodes', 'system_users')
return queryset

View File

@ -96,7 +96,7 @@ class AssetPermissionAllAssetSerializer(serializers.Serializer):
class AssetPermissionNodeRelationSerializer(RelationMixin, serializers.ModelSerializer):
node_display = serializers.SerializerMethodField()
node_display = serializers.CharField(source='node.full_value', read_only=True)
class Meta(RelationMixin.Meta):
model = AssetPermission.nodes.through
@ -104,16 +104,6 @@ class AssetPermissionNodeRelationSerializer(RelationMixin, serializers.ModelSeri
'id', 'node', "node_display",
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tree = Node.tree()
def get_node_display(self, obj):
if hasattr(obj, 'node_key'):
return self.tree.get_node_full_tag(obj.node_key)
else:
return obj.node.full_value
class AssetPermissionSystemUserRelationSerializer(RelationMixin, serializers.ModelSerializer):
systemuser_display = serializers.ReadOnlyField()

View File

@ -82,8 +82,6 @@ class AssetGrantedSerializer(serializers.ModelSerializer):
class NodeGrantedSerializer(serializers.ModelSerializer):
assets_amount = serializers.SerializerMethodField()
class Meta:
model = Node
fields = [
@ -91,15 +89,6 @@ class NodeGrantedSerializer(serializers.ModelSerializer):
]
read_only_fields = fields
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tree = self.context.get("tree")
def get_assets_amount(self, obj):
if not self.tree:
return 0
return self.tree.assets_amount(obj.key)
class ActionsSerializer(serializers.Serializer):
actions = ActionsField(read_only=True)

View File

@ -1,52 +1,96 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import m2m_changed, post_save, post_delete
from django.db.models.signals import m2m_changed, pre_delete, pre_save
from django.dispatch import receiver
from perms.tasks import create_rebuild_user_tree_task, \
create_rebuild_user_tree_task_by_related_nodes_or_assets
from users.models import User, UserGroup
from assets.models import Asset
from common.utils import get_logger
from common.decorator import on_transaction_commit
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from .models import AssetPermission, RemoteAppPermission
from .utils.asset_permission import AssetPermissionUtil
logger = get_logger(__file__)
@receiver([post_save, post_delete], sender=AssetPermission)
@on_transaction_commit
def on_permission_change(sender, action='', **kwargs):
logger.debug('Asset permission changed, refresh user tree cache')
AssetPermissionUtil.expire_all_user_tree_cache()
@receiver([pre_save], sender=AssetPermission)
def on_asset_perm_deactive(instance: AssetPermission, **kwargs):
try:
old = AssetPermission.objects.only('is_active').get(id=instance.id)
if instance.is_active != old.is_active:
create_rebuild_user_tree_task_by_asset_perm(instance)
except AssetPermission.DoesNotExist:
pass
# Todo: 检查授权规则到期,从而修改授权规则
@receiver([pre_delete], sender=AssetPermission)
def on_asset_permission_delete(instance, **kwargs):
# 授权删除之前,查出所有相关用户
create_rebuild_user_tree_task_by_asset_perm(instance)
def create_rebuild_user_tree_task_by_asset_perm(asset_perm: AssetPermission):
user_ids = set()
user_ids.update(
UserGroup.objects.filter(assetpermissions=asset_perm).distinct().values_list('users__id', flat=True)
)
user_ids.update(
User.objects.filter(assetpermissions=asset_perm).distinct().values_list('id', flat=True)
)
create_rebuild_user_tree_task(user_ids)
def need_rebuild_mapping_node(action):
return action in (POST_REMOVE, POST_ADD, POST_CLEAR)
@receiver(m2m_changed, sender=AssetPermission.nodes.through)
def on_permission_nodes_changed(sender, instance=None, action='', reverse=None, **kwargs):
if action != 'post_add' and reverse:
def on_permission_nodes_changed(instance, action, reverse, pk_set, model, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task_by_asset_perm(instance)
if action != POST_ADD:
return
logger.debug("Asset permission nodes change signal received")
nodes = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
nodes = model.objects.filter(pk__in=pk_set)
system_users = instance.system_users.all()
# TODO 待优化
for system_user in system_users:
system_user.nodes.add(*tuple(nodes))
system_user.nodes.add(*nodes)
@receiver(m2m_changed, sender=AssetPermission.assets.through)
def on_permission_assets_changed(sender, instance=None, action='', reverse=None, **kwargs):
if action != 'post_add' and reverse:
def on_permission_assets_changed(instance, action, reverse, pk_set, model, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task_by_asset_perm(instance)
if action != POST_ADD:
return
logger.debug("Asset permission assets change signal received")
assets = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
assets = model.objects.filter(pk__in=pk_set)
# TODO 待优化
system_users = instance.system_users.all()
for system_user in system_users:
system_user.assets.add(*tuple(assets))
@receiver(m2m_changed, sender=AssetPermission.system_users.through)
def on_asset_permission_system_users_changed(sender, instance=None, action='',
reverse=False, **kwargs):
if action != 'post_add' and reverse:
def on_asset_permission_system_users_changed(instance, action, reverse, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if action != POST_ADD:
return
logger.debug("Asset permission system_users change signal received")
system_users = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
@ -63,28 +107,42 @@ def on_asset_permission_system_users_changed(sender, instance=None, action='',
@receiver(m2m_changed, sender=AssetPermission.users.through)
def on_asset_permission_users_changed(sender, instance=None, action='',
reverse=False, **kwargs):
if action != 'post_add' and reverse:
def on_asset_permission_users_changed(instance, action, reverse, pk_set, model, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
create_rebuild_user_tree_task(pk_set)
if action != POST_ADD:
return
logger.debug("Asset permission users change signal received")
users = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
users = model.objects.filter(pk__in=pk_set)
system_users = instance.system_users.all()
# TODO 待优化
for system_user in system_users:
if system_user.username_same_with_user:
system_user.users.add(*tuple(users))
@receiver(m2m_changed, sender=AssetPermission.user_groups.through)
def on_asset_permission_user_groups_changed(sender, instance=None, action='',
reverse=False, **kwargs):
if action != 'post_add' and reverse:
def on_asset_permission_user_groups_changed(instance, action, pk_set, model,
reverse, **kwargs):
if reverse:
raise M2MReverseNotAllowed
if need_rebuild_mapping_node(action):
user_ids = User.objects.filter(groups__id__in=pk_set).distinct().values_list('id', flat=True)
create_rebuild_user_tree_task(user_ids)
if action != POST_ADD:
return
logger.debug("Asset permission user groups change signal received")
groups = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
groups = model.objects.filter(pk__in=pk_set)
system_users = instance.system_users.all()
# TODO 待优化
for system_user in system_users:
if system_user.username_same_with_user:
system_user.groups.add(*tuple(groups))
@ -93,7 +151,7 @@ def on_asset_permission_user_groups_changed(sender, instance=None, action='',
@receiver(m2m_changed, sender=RemoteAppPermission.system_users.through)
def on_remote_app_permission_system_users_changed(sender, instance=None,
action='', reverse=False, **kwargs):
if action != 'post_add' or reverse:
if action != POST_ADD or reverse:
return
system_users = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
logger.debug("Remote app permission system_users change signal received")
@ -110,12 +168,41 @@ def on_remote_app_permission_system_users_changed(sender, instance=None,
@receiver(m2m_changed, sender=RemoteAppPermission.users.through)
def on_remoteapps_permission_users_changed(sender, instance=None, action='',
reverse=False, **kwargs):
on_asset_permission_users_changed(sender, instance=instance, action=action,
reverse=reverse, **kwargs)
if action != POST_ADD and reverse:
return
logger.debug("Asset permission users change signal received")
users = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
system_users = instance.system_users.all()
for system_user in system_users:
if system_user.username_same_with_user:
system_user.users.add(*tuple(users))
@receiver(m2m_changed, sender=RemoteAppPermission.user_groups.through)
def on_remoteapps_permission_user_groups_changed(sender, instance=None, action='',
reverse=False, **kwargs):
on_asset_permission_user_groups_changed(sender, instance=instance,
action=action, reverse=reverse, **kwargs)
if action != POST_ADD and reverse:
return
logger.debug("Asset permission user groups change signal received")
groups = kwargs['model'].objects.filter(pk__in=kwargs['pk_set'])
system_users = instance.system_users.all()
for system_user in system_users:
if system_user.username_same_with_user:
system_user.groups.add(*tuple(groups))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action):
return
if reverse:
asset_pk_set = pk_set
node_pk_set = [instance.id]
else:
asset_pk_set = [instance.id]
node_pk_set = pk_set
create_rebuild_user_tree_task_by_related_nodes_or_assets.delay(node_pk_set, asset_pk_set)

View File

@ -1,10 +1,122 @@
# ~*~ coding: utf-8 ~*~
from __future__ import absolute_import, unicode_literals
from datetime import timedelta
from django.db import transaction
from django.db.models import Q
from django.db.transaction import atomic
from celery import shared_task
from common.utils import get_logger, encrypt_password
from common.utils import get_logger
from common.utils.timezone import now, dt_formater, dt_parser
from users.models import User
from assets.models import Node
from perms.models import RebuildUserTreeTask, AssetPermission
from perms.utils.user_asset_permission import rebuild_user_mapping_nodes_if_need_with_lock, lock
logger = get_logger(__file__)
@shared_task(queue='node_tree')
def rebuild_user_mapping_nodes_celery_task(user_id):
user = User.objects.get(id=user_id)
try:
rebuild_user_mapping_nodes_if_need_with_lock(user)
except lock.SomeoneIsDoingThis:
pass
@shared_task(queue='node_tree')
def dispatch_mapping_node_tasks():
user_ids = RebuildUserTreeTask.objects.all().values_list('user_id', flat=True).distinct()
logger.info(f'>>> dispatch_mapping_node_tasks for users {list(user_ids)}')
for id in user_ids:
rebuild_user_mapping_nodes_celery_task.delay(id)
@shared_task(queue='check_asset_perm_expired')
@atomic()
def check_asset_permission_expired():
"""
这里的任务要足够短不要影响周期任务
"""
from settings.models import Setting
setting_name = 'last_asset_perm_expired_check'
end = now()
default_start = end - timedelta(days=36000) # Long long ago in china
defaults = {'value': dt_formater(default_start)}
setting, created = Setting.objects.get_or_create(
name=setting_name, defaults=defaults
)
if created:
start = default_start
else:
start = dt_parser(setting.value)
setting.value = dt_formater(end)
setting.save()
ids = AssetPermission.objects.filter(
date_expired__gte=start, date_expired__lte=end
).distinct().values_list('id', flat=True)
logger.info(f'>>> checking {start} to {end} have {ids} expired')
dispatch_process_expired_asset_permission.delay(list(ids))
@shared_task(queue='node_tree')
def dispatch_process_expired_asset_permission(asset_perm_ids):
user_ids = User.objects.filter(
Q(assetpermissions__id__in=asset_perm_ids) |
Q(groups__assetpermissions__id__in=asset_perm_ids)
).distinct().values_list('id', flat=True)
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=user_id) for user_id in user_ids]
)
dispatch_mapping_node_tasks.delay()
def create_rebuild_user_tree_task(user_ids):
RebuildUserTreeTask.objects.bulk_create(
[RebuildUserTreeTask(user_id=i) for i in user_ids]
)
transaction.on_commit(dispatch_mapping_node_tasks.delay)
@shared_task(queue='node_tree')
def create_rebuild_user_tree_task_by_related_nodes_or_assets(node_ids, asset_ids):
node_ids = set(node_ids)
node_keys = set()
nodes = Node.objects.filter(id__in=node_ids)
for _node in nodes:
node_keys.update(_node.get_ancestor_keys())
node_ids.update(
Node.objects.filter(key__in=node_keys).values_list('id', flat=True)
)
asset_perm_ids = set()
asset_perm_ids.update(
AssetPermission.objects.filter(
assets__id__in=asset_ids
).values_list('id', flat=True).distinct()
)
asset_perm_ids.update(
AssetPermission.objects.filter(
nodes__id__in=node_ids
).values_list('id', flat=True).distinct()
)
user_ids = set()
user_ids.update(
User.objects.filter(
assetpermissions__id__in=asset_perm_ids
).distinct().values_list('id', flat=True)
)
user_ids.update(
User.objects.filter(
groups__assetpermissions__id__in=asset_perm_ids
).distinct().values_list('id', flat=True)
)
create_rebuild_user_tree_task(user_ids)

View File

@ -2,6 +2,7 @@
from django.urls import path, include
from rest_framework_bulk.routes import BulkRouter
from .. import api
router = BulkRouter()
@ -13,46 +14,70 @@ router.register('asset-permissions-nodes-relations', api.AssetPermissionNodeRela
router.register('asset-permissions-system-users-relations', api.AssetPermissionSystemUserRelationViewSet, 'asset-permissions-system-users-relation')
user_permission_urlpatterns = [
path('<uuid:pk>/assets/', api.UserGrantedAssetsApi.as_view(), name='user-assets'),
path('assets/', api.UserGrantedAssetsApi.as_view(), name='my-assets'),
# 统一说明:
# `<uuid:pk>`: `User.pk`
# 直接授权:在 `AssetPermission` 中关联的对象
# Assets as tree
path('<uuid:pk>/assets/tree/', api.UserGrantedAssetsAsTreeApi.as_view(), name='user-assets-as-tree'),
path('assets/tree/', api.UserGrantedAssetsAsTreeApi.as_view(), name='my-assets-as-tree'),
# ---------------------------------------------------------
# 获取用户所有直接授权的资产
# Nodes
path('<uuid:pk>/nodes/', api.UserGrantedNodesApi.as_view(), name='user-nodes'),
path('nodes/', api.UserGrantedNodesApi.as_view(), name='my-nodes'),
# 以 serializer 格式返回
path('<uuid:pk>/assets/', api.UserDirectGrantedAssetsForAdminApi.as_view(), name='user-assets'),
path('assets/', api.MyDirectGrantedAssetsApi.as_view(), name='my-assets'),
# Node children
path('<uuid:pk>/nodes/children/', api.UserGrantedNodeChildrenApi.as_view(), name='user-nodes-children'),
path('nodes/children/', api.UserGrantedNodesApi.as_view(), name='my-nodes-children'),
# Tree Node 的数据格式返回
path('<uuid:pk>/assets/tree/', api.UserDirectGrantedAssetsAsTreeForAdminApi.as_view(), name='user-assets-as-tree'),
path('assets/tree/', api.MyAllAssetsAsTreeApi.as_view(), name='my-assets-as-tree'),
path('ungroup/assets/tree/', api.MyUngroupAssetsAsTreeApi.as_view(), name='my-ungroup-assets-as-tree'),
# ^--------------------------------------------------------^
# Node as tree
path('<uuid:pk>/nodes/tree/', api.UserGrantedNodesAsTreeApi.as_view(), name='user-nodes-as-tree'),
path('nodes/tree/', api.UserGrantedNodesAsTreeApi.as_view(), name='my-nodes-as-tree'),
# 获取用户所有`直接授权的节点`与`直接授权资产`关联的节点
# 以 serializer 格式返回
path('<uuid:pk>/nodes/', api.UserGrantedNodesForAdminApi.as_view(), name='user-nodes'),
path('nodes/', api.MyGrantedNodesApi.as_view(), name='my-nodes'),
# Node with assets as tree
path('<uuid:pk>/nodes-with-assets/tree/', api.UserGrantedNodesWithAssetsAsTreeApi.as_view(), name='user-nodes-with-assets-as-tree'),
path('nodes-with-assets/tree/', api.UserGrantedNodesWithAssetsAsTreeApi.as_view(), name='my-nodes-with-assets-as-tree'),
# 以 Tree Node 的数据格式返回
path('<uuid:pk>/nodes/tree/', api.MyGrantedNodesAsTreeApi.as_view(), name='user-nodes-as-tree'),
path('nodes/tree/', api.MyGrantedNodesAsTreeApi.as_view(), name='my-nodes-as-tree'),
# ^--------------------------------------------------------^
# Node children as tree
path('<uuid:pk>/nodes/children/tree/', api.UserGrantedNodeChildrenAsTreeApi.as_view(), name='user-nodes-children-as-tree'),
path('nodes/children/tree/', api.UserGrantedNodeChildrenAsTreeApi.as_view(), name='my-nodes-children-as-tree'),
# 一层一层的获取用户授权的节点,
# 以 Serializer 的数据格式返回
path('<uuid:pk>/nodes/children/', api.UserGrantedNodeChildrenForAdminApi.as_view(), name='user-nodes-children'),
path('nodes/children/', api.MyGrantedNodeChildrenApi.as_view(), name='my-nodes-children'),
# Node children with assets as tree
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='my-nodes-children-with-assets-as-tree'),
# 以 Tree Node 的数据格式返回
path('<uuid:pk>/nodes/children/tree/', api.UserGrantedNodeChildrenAsTreeForAdminApi.as_view(), name='user-nodes-children-as-tree'),
# 部分调用位置
# - 普通用户 -> 我的资产 -> 展开节点 时调用
path('nodes/children/tree/', api.MyGrantedNodeChildrenAsTreeApi.as_view(), name='my-nodes-children-as-tree'),
# ^--------------------------------------------------------^
# Node assets
path('<uuid:pk>/nodes/<uuid:node_id>/assets/', api.UserGrantedNodeAssetsApi.as_view(), name='user-node-assets'),
path('nodes/<uuid:node_id>/assets/', api.UserGrantedNodeAssetsApi.as_view(), name='my-node-assets'),
# 此接口会返回整棵树
# 普通用户 -> 命令执行 -> 左侧树
path('nodes-with-assets/tree/', api.MyGrantedNodesWithAssetsAsTreeApi.as_view(), name='my-nodes-with-assets-as-tree'),
# 主要用于 luna 页面,带资产的节点树
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('nodes/children-with-assets/tree/', api.MyGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='my-nodes-children-with-assets-as-tree'),
# 查询授权树上某个节点的所有资产
path('<uuid:pk>/nodes/<uuid:node_id>/assets/', api.UserGrantedNodeAssetsForAdminApi.as_view(), name='user-node-assets'),
path('nodes/<uuid:node_id>/assets/', api.MyGrantedNodeAssetsApi.as_view(), name='my-node-assets'),
# 未分组的资产
path('<uuid:pk>/nodes/ungrouped/assets/', api.UserDirectGrantedAssetsForAdminApi.as_view(), name='user-ungrouped-assets'),
path('nodes/ungrouped/assets/', api.MyDirectGrantedAssetsApi.as_view(), name='my-ungrouped-assets'),
# 收藏的资产
path('<uuid:pk>/nodes/favorite/assets/', api.UserFavoriteGrantedAssetsForAdminApi.as_view(), name='user-ungrouped-assets'),
path('nodes/favorite/assets/', api.MyFavoriteGrantedAssetsApi.as_view(), name='my-ungrouped-assets'),
# Asset System users
path('<uuid:pk>/assets/<uuid:asset_id>/system-users/', api.UserGrantedAssetSystemUsersApi.as_view(), name='user-asset-system-users'),
path('assets/<uuid:asset_id>/system-users/', api.UserGrantedAssetSystemUsersApi.as_view(), name='my-asset-system-users'),
path('<uuid:pk>/assets/<uuid:asset_id>/system-users/', api.UserGrantedAssetSystemUsersForAdminApi.as_view(), name='user-asset-system-users'),
path('assets/<uuid:asset_id>/system-users/', api.MyGrantedAssetSystemUsersApi.as_view(), name='my-asset-system-users'),
# Expire user permission cache
# TODO 要废弃 Expire user permission cache
path('<uuid:pk>/asset-permissions/cache/', api.UserAssetPermissionsCacheApi.as_view(),
name='user-asset-permission-cache'),
path('asset-permissions/cache/', api.UserAssetPermissionsCacheApi.as_view(), name='my-asset-permission-cache'),

View File

@ -4,4 +4,5 @@
from .asset_permission import *
from .remote_app_permission import *
from .database_app_permission import *
from .k8s_app_permission import *
from .k8s_app_permission import *
from .user_asset_permission import *

View File

@ -1,488 +1,48 @@
# coding: utf-8
import time
import pickle
from collections import defaultdict
from functools import reduce
from django.core.cache import cache
from django.db.models import Q
from django.conf import settings
from orgs.utils import current_org
from common.utils import get_logger, timeit, lazyproperty
from common.tree import TreeNode
from assets.utils import TreeService
from common.utils import get_logger
from ..models import AssetPermission
from ..hands import Node, Asset, SystemUser, User, FavoriteAsset
from ..hands import Asset, User, UserGroup
from perms.models.base import BasePermissionQuerySet
logger = get_logger(__file__)
__all__ = [
'ParserNode', 'AssetPermissionUtil',
]
def get_asset_system_users_id_with_actions(asset_perm_queryset: BasePermissionQuerySet, asset: Asset):
nodes = asset.get_nodes()
node_keys = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
node_keys.update(ancestor_keys)
def get_user_permissions(user, include_group=True):
if include_group:
groups = user.groups.all()
arg = Q(users=user) | Q(user_groups__in=groups)
else:
arg = Q(users=user)
return AssetPermission.get_queryset_with_prefetch().filter(arg)
def get_user_group_permissions(user_group):
return AssetPermission.get_queryset_with_prefetch().filter(
user_groups=user_group
queryset = asset_perm_queryset.filter(
Q(assets=asset) |
Q(nodes__key__in=node_keys)
)
asset_protocols = asset.protocols_as_dict.keys()
values = queryset.filter(
system_users__protocol__in=asset_protocols
).distinct().values_list('system_users', 'actions')
system_users_actions = defaultdict(int)
for system_user_id, actions in values:
if None in (system_user_id, actions):
continue
system_users_actions[system_user_id] |= actions
return system_users_actions
def get_asset_permissions(asset, include_node=True):
if include_node:
nodes = asset.get_all_nodes(flat=True)
arg = Q(assets=asset) | Q(nodes__in=nodes)
else:
arg = Q(assets=asset)
return AssetPermission.objects.valid().filter(arg)
def get_node_permissions(node):
return AssetPermission.objects.valid().filter(nodes=node)
def get_system_user_permissions(system_user):
return AssetPermission.objects.valid().filter(
system_users=system_user
def get_asset_system_users_id_with_actions_by_user(user: User, asset: Asset):
queryset = AssetPermission.objects.filter(
Q(users=user) | Q(user_groups__users=user)
)
return get_asset_system_users_id_with_actions(queryset, asset)
class AssetPermissionUtilCacheMixin:
user_tree_cache_key = 'USER_PERM_TREE_{}_{}_{}'
user_tree_cache_ttl = settings.ASSETS_PERM_CACHE_TIME
user_tree_cache_enable = settings.ASSETS_PERM_CACHE_ENABLE
user_tree_map = {}
cache_policy = '0'
obj_id = ''
_filter_id = 'None'
@property
def cache_key(self):
return self.get_cache_key()
def get_cache_key(self, org_id=None):
if org_id is None:
org_id = current_org.org_id()
key = self.user_tree_cache_key.format(
org_id, self.obj_id, self._filter_id
)
return key
def expire_user_tree_cache(self):
cache.delete(self.cache_key)
@classmethod
def expire_all_user_tree_cache(cls):
expire_cache_key = "USER_TREE_EXPIRED_AT"
latest_expired = cache.get(expire_cache_key, 0)
now = time.time()
if now - latest_expired < 60:
return
key = cls.user_tree_cache_key.format('*', '1', '1')
key = key.replace('_1', '')
cache.delete_pattern(key)
cache.set(expire_cache_key, now)
@classmethod
def expire_org_tree_cache(cls, org_id=None):
if org_id is None:
org_id = current_org.org_id()
key = cls.user_tree_cache_key.format(org_id, '*', '1')
key = key.replace('_1', '')
cache.delete_pattern(key)
def set_user_tree_to_cache(self, user_tree):
data = pickle.dumps(user_tree)
cache.set(self.cache_key, data, self.user_tree_cache_ttl)
def get_user_tree_from_cache(self):
data = cache.get(self.cache_key)
if not data:
return None
user_tree = pickle.loads(data)
return user_tree
@timeit
def get_user_tree_from_cache_if_need(self):
if not self.user_tree_cache_enable:
return None
if self.cache_policy == '1':
return self.get_user_tree_from_cache()
elif self.cache_policy == '2':
self.expire_user_tree_cache()
return None
else:
return None
def set_user_tree_to_cache_if_need(self, user_tree):
if self.cache_policy == '0':
return
if not self.user_tree_cache_enable:
return None
self.set_user_tree_to_cache(user_tree)
class AssetPermissionUtil(AssetPermissionUtilCacheMixin):
get_permissions_map = {
"User": get_user_permissions,
"UserGroup": get_user_group_permissions,
"Asset": get_asset_permissions,
"Node": get_node_permissions,
"SystemUser": get_system_user_permissions,
}
assets_only = (
'id', 'hostname', 'ip', "platform", "domain_id",
'comment', 'is_active', 'os', 'org_id'
def get_asset_system_users_id_with_actions_by_group(group: UserGroup, asset: Asset):
queryset = AssetPermission.objects.filter(
user_groups=group
)
def __init__(self, obj=None, cache_policy='0'):
self.object = obj
self.cache_policy = cache_policy
self.obj_id = str(obj.id) if obj else None
self._permissions = None
self._filter_id = 'None' # 当通过filter更改 permission是标记
self.change_org_if_need()
self._user_tree = None
self._user_tree_filter_id = 'None'
if not isinstance(obj, User):
self.cache_policy = '0'
@staticmethod
def change_org_if_need():
pass
@lazyproperty
def full_tree(self):
return Node.tree()
@property
def permissions(self):
if self._permissions is not None:
return self._permissions
if self.object is None:
return AssetPermission.objects.none()
object_cls = self.object.__class__.__name__
func = self.get_permissions_map[object_cls]
permissions = func(self.object)
self._permissions = permissions
return permissions
@timeit
def filter_permissions(self, **filters):
self.cache_policy = '0'
self._permissions = self.permissions.filter(**filters)
@lazyproperty
def user_tree(self):
return self.get_user_tree()
@timeit
def get_assets_direct(self):
"""
返回直接授权的资产
并添加到tree.assets中
:return:
{asset.id: {system_user.id: actions, }, }
"""
assets_ids = self.permissions.values_list('assets', flat=True)
return Asset.objects.filter(id__in=assets_ids)
@timeit
def get_nodes_direct(self):
"""
返回直接授权的节点
并将节点添加到tree.nodes中并将节点下的资产添加到tree.assets中
:return:
{node.key: {system_user.id: actions,}, }
"""
nodes_ids = self.permissions.values_list('nodes', flat=True)
return Node.objects.filter(id__in=nodes_ids)
@timeit
def add_direct_nodes_to_user_tree(self, user_tree):
"""
将授权规则的节点放到用户树上, 从full tree中粘贴子树
"""
nodes_direct_keys = self.permissions \
.exclude(nodes__isnull=True) \
.values_list('nodes__key', flat=True) \
.distinct()
nodes_direct_keys = list(nodes_direct_keys)
# 排序,保证从上层节点开始加
nodes_direct_keys.sort(key=lambda x: len(x))
for key in nodes_direct_keys:
# 如果树上已经有这个节点,代表子树已经存在
if user_tree.contains(key):
continue
# 找到这个节点的父节点如果父节点不在树上则挂到ROOT上
parent = self.full_tree.parent(key)
if not user_tree.contains(parent.identifier):
parent = user_tree.root_node()
subtree = self.full_tree.subtree(key)
user_tree.paste(parent.identifier, subtree, deep=True)
for node in user_tree.all_nodes_itr():
assets = list(self.full_tree.assets(node.identifier))
user_tree.set_assets(node.identifier, assets)
@timeit
def add_single_assets_node_to_user_tree(self, user_tree):
"""
将单独授权的资产放到树上如果设置了单独资产到 未分组中则放到未分组中
如果没有则查询资产属于的资产组放到树上
"""
# 添加单独授权资产的节点
nodes_single_assets = defaultdict(set)
queryset = self.permissions.exclude(assets__isnull=True) \
.values_list('assets', 'assets__nodes__key') \
.distinct()
for item in queryset:
nodes_single_assets[item[1]].add(item[0])
nodes_single_assets.pop(None, None)
for key in tuple(nodes_single_assets.keys()):
if user_tree.contains(key):
nodes_single_assets.pop(key)
if not nodes_single_assets:
return
# 如果要设置到ungroup中
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
node_key = Node.ungrouped_key
node_value = Node.ungrouped_value
user_tree.create_node(
identifier=node_key, tag=node_value,
parent=user_tree.root,
)
assets = set()
for _assets in nodes_single_assets.values():
assets.update(set(_assets))
user_tree.set_assets(node_key, assets)
return
# 获取单独授权资产,并没有在授权的节点上
for key, assets in nodes_single_assets.items():
if not self.full_tree.contains(key):
continue
node = self.full_tree.get_node(key, deep=True)
parent_id = self.full_tree.parent(key).identifier
parent = user_tree.get_node(parent_id)
if not parent:
parent = user_tree.root_node()
user_tree.add_node(node, parent)
user_tree.set_assets(node.identifier, assets)
@timeit
def parse_user_tree_to_full_tree(self, user_tree):
"""
经过前面两个动作用户授权的节点已放到树上但是树不是完整的
这里要将树构造成一个完整的树
"""
# 开始修正user_tree保证父节点都在树上
root_children = user_tree.children('')
for child in root_children:
# print("child: {}".format(child.identifier))
if child.identifier.isdigit():
continue
if child.identifier.startswith('-'):
continue
ancestors = self.full_tree.ancestors(
child.identifier, with_self=False, deep=True,
)
# print("Get ancestors: {}".format(len(ancestors)))
if not ancestors:
continue
user_tree.safe_add_ancestors(child, ancestors)
def add_favorite_node_if_need(self, user_tree):
if not isinstance(self.object, User):
return
node_key = Node.favorite_key
node_value = Node.favorite_value
user_tree.create_node(
identifier=node_key, tag=node_value,
parent=user_tree.root,
)
node = user_tree.get_node(node_key)
assets_id = FavoriteAsset.get_user_favorite_assets_id(self.object)
all_valid_assets = user_tree.all_valid_assets(user_tree.root)
valid_assets_id = set(assets_id) & all_valid_assets
user_tree.set_assets(node_key, valid_assets_id)
# 必须设置这个,否则看不到个数
node.data['all_assets'] = None
def set_user_tree_to_local(self, user_tree):
self._user_tree = user_tree
self._user_tree_filter_id = self._filter_id
def get_user_tree_from_local(self):
if self._user_tree and self._user_tree_filter_id == self._filter_id:
return self._user_tree
return None
@timeit
def get_user_tree(self):
user_tree = self.get_user_tree_from_cache_if_need()
if user_tree:
return user_tree
user_tree = TreeService()
full_tree_root = self.full_tree.root_node()
user_tree.create_node(
tag=full_tree_root.tag,
identifier=full_tree_root.identifier
)
self.add_direct_nodes_to_user_tree(user_tree)
self.add_single_assets_node_to_user_tree(user_tree)
self.parse_user_tree_to_full_tree(user_tree)
self.add_favorite_node_if_need(user_tree)
self.set_user_tree_to_cache_if_need(user_tree)
self.set_user_tree_to_local(user_tree)
# print(user_tree)
return user_tree
# Todo: 是否可以获取多个资产的系统用户
def get_asset_system_users_id_with_actions(self, asset):
nodes = asset.get_nodes()
nodes_keys_related = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys(with_self=True)
nodes_keys_related.update(set(ancestor_keys))
kwargs = {"assets": asset}
if nodes_keys_related:
kwargs["nodes__key__in"] = nodes_keys_related
queryset = self.permissions
if kwargs == 1:
queryset = queryset.filter(**kwargs)
elif len(kwargs) > 1:
kwargs = [{k: v} for k, v in kwargs.items()]
args = [Q(**kw) for kw in kwargs]
args = reduce(lambda x, y: x | y, args)
queryset = queryset.filter(args)
else:
queryset = queryset.none()
asset_protocols = asset.protocols_as_dict.keys()
values = queryset.filter(system_users__protocol__in=asset_protocols).distinct()\
.values_list('system_users', 'actions')
system_users_actions = defaultdict(int)
for system_user_id, actions in values:
if None in (system_user_id, actions):
continue
for i, action in values:
system_users_actions[i] |= actions
return system_users_actions
def get_permissions_nodes_and_assets(self):
from assets.models import Node
permissions = self.permissions
nodes_keys = permissions.exclude(nodes__isnull=True)\
.values_list('nodes__key', flat=True)
assets_ids = permissions.exclude(assets__isnull=True)\
.values_list('assets', flat=True)
nodes_keys = set(nodes_keys)
assets_ids = set(assets_ids)
nodes_keys = Node.clean_children_keys(nodes_keys)
return nodes_keys, assets_ids
@timeit
def get_assets(self):
nodes_keys, assets_ids = self.get_permissions_nodes_and_assets()
queryset = Node.get_nodes_all_assets(
nodes_keys, extra_assets_ids=assets_ids
)
return queryset.valid()
def get_nodes_assets(self, node, deep=False):
if deep:
assets_ids = self.user_tree.all_assets(node.key)
else:
assets_ids = self.user_tree.assets(node.key)
queryset = Asset.objects.filter(id__in=assets_ids)
return queryset.valid()
def get_nodes(self):
return [n.identifier for n in self.user_tree.all_nodes_itr()]
def get_system_users(self):
system_users_id = self.permissions.values_list('system_users', flat=True).distinct()
return SystemUser.objects.filter(id__in=system_users_id)
class ParserNode:
nodes_only_fields = ("key", "value", "id")
assets_only_fields = ("hostname", "id", "ip", "protocols", "domain", "org_id")
system_users_only_fields = (
"id", "name", "username", "protocol", "priority", "login_mode",
)
@staticmethod
def parse_node_to_tree_node(node):
name = '{} ({})'.format(node.value, node.assets_amount)
data = {
'id': node.key,
'name': name,
'title': name,
'pId': node.parent_key,
'isParent': True,
'open': node.is_org_root(),
'meta': {
'node': {
"id": node.id,
"key": node.key,
"value": node.value,
},
'type': 'node'
}
}
tree_node = TreeNode(**data)
return tree_node
@staticmethod
def parse_asset_to_tree_node(node, asset):
icon_skin = 'file'
platform = asset.platform_base.lower()
if platform == 'windows':
icon_skin = 'windows'
elif platform == 'linux':
icon_skin = 'linux'
parent_id = node.key if node else ''
data = {
'id': str(asset.id),
'name': asset.hostname,
'title': asset.ip,
'pId': parent_id,
'isParent': False,
'open': False,
'iconSkin': icon_skin,
'nocheck': not asset.has_protocol('ssh'),
'meta': {
'type': 'asset',
'asset': {
'id': asset.id,
'hostname': asset.hostname,
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
'domain': asset.domain_id,
'org_name': asset.org_name,
'org_id': asset.org_id
},
}
}
tree_node = TreeNode(**data)
return tree_node
return get_asset_system_users_id_with_actions(queryset, asset)

View File

@ -71,10 +71,10 @@ class DatabaseAppPermissionUtil:
return system_users
def construct_database_apps_tree_root():
def construct_database_apps_tree_root(amount):
tree_root = {
'id': 'ID_DATABASE_APP_ROOT',
'name': _('DatabaseApp'),
'name': '{} ({})'.format(_('DatabaseApp'), amount),
'title': 'DatabaseApp',
'pId': '',
'open': False,

View File

@ -64,10 +64,10 @@ class K8sAppPermissionUtil:
return system_users
def construct_k8s_apps_tree_root():
def construct_k8s_apps_tree_root(amount):
tree_root = {
'id': 'ID_K8S_APP_ROOT',
'name': _('KubernetesApp'),
'name': '{} ({})'.format(_('KubernetesApp'), amount),
'title': 'K8sApp',
'pId': '',
'open': False,

View File

@ -70,10 +70,10 @@ class RemoteAppPermissionUtil:
return system_users
def construct_remote_apps_tree_root():
def construct_remote_apps_tree_root(amount):
tree_root = {
'id': 'ID_REMOTE_APP_ROOT',
'name': _('RemoteApp'),
'name': '{} ({})'.format(_('RemoteApp'), amount),
'title': 'RemoteApp',
'pId': '',
'open': False,

View File

@ -0,0 +1,503 @@
from functools import reduce, wraps
from operator import or_, and_
from uuid import uuid4
import threading
import inspect
from django.conf import settings
from django.db.models import F, Q, Value, BooleanField
from django.utils.translation import gettext as _
from common.http import is_true
from common.utils import get_logger
from common.const.distributed_lock_key import UPDATE_MAPPING_NODE_TASK_LOCK_KEY
from orgs.utils import tmp_to_root_org
from common.utils.timezone import dt_formater, now
from assets.models import Node, Asset, FavoriteAsset
from django.db.transaction import atomic
from orgs import lock
from perms.models import UserGrantedMappingNode, RebuildUserTreeTask, AssetPermission
from users.models import User
logger = get_logger(__name__)
ADD = 'add'
REMOVE = 'remove'
UNGROUPED_NODE_KEY = 'ungrouped'
FAVORITE_NODE_KEY = 'favorite'
TMP_GRANTED_FIELD = '_granted'
TMP_ASSET_GRANTED_FIELD = '_asset_granted'
TMP_GRANTED_ASSETS_AMOUNT_FIELD = '_granted_assets_amount'
# 使用场景
# Asset.objects.filter(get_user_resources_q_granted_by_permissions(user))
def get_user_resources_q_granted_by_permissions(user: User):
"""
获取用户关联的 asset permission 或者 用户组关联的 asset permission 获取规则,
前提 AssetPermission 对象中的 related_name granted_by_permissions
:param user:
:return:
"""
_now = now()
return reduce(and_, (
Q(granted_by_permissions__date_start__lt=_now),
Q(granted_by_permissions__date_expired__gt=_now),
Q(granted_by_permissions__is_active=True),
(
Q(granted_by_permissions__users=user) |
Q(granted_by_permissions__user_groups__users=user)
)
))
# 使用场景
# `Node.objects.annotate(**node_annotate_mapping_node)`
node_annotate_mapping_node = {
TMP_GRANTED_FIELD: F('mapping_nodes__granted'),
TMP_ASSET_GRANTED_FIELD: F('mapping_nodes__asset_granted'),
TMP_GRANTED_ASSETS_AMOUNT_FIELD: F('mapping_nodes__assets_amount')
}
# 使用场景
# `Node.objects.annotate(**node_annotate_set_granted)`
node_annotate_set_granted = {
TMP_GRANTED_FIELD: Value(True, output_field=BooleanField()),
}
def is_direct_granted_by_annotate(node):
return getattr(node, TMP_GRANTED_FIELD, False)
def is_asset_granted(node):
return getattr(node, TMP_ASSET_GRANTED_FIELD, False)
def get_granted_assets_amount(node):
return getattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, 0)
def set_granted(obj):
setattr(obj, TMP_GRANTED_FIELD, True)
def set_asset_granted(obj):
setattr(obj, TMP_ASSET_GRANTED_FIELD, True)
VALUE_TEMPLATE = '{stage}:{rand_str}:thread:{thread_name}:{thread_id}:{now}'
def _generate_value(stage=lock.DOING):
cur_thread = threading.current_thread()
return VALUE_TEMPLATE.format(
stage=stage,
thread_name=cur_thread.name,
thread_id=cur_thread.ident,
now=dt_formater(now()),
rand_str=uuid4()
)
def build_user_mapping_node_lock(func):
@wraps(func)
def wrapper(*args, **kwargs):
call_args = inspect.getcallargs(func, *args, **kwargs)
user = call_args.get('user')
if user is None or not isinstance(user, User):
raise ValueError('You function must have `user` argument')
key = UPDATE_MAPPING_NODE_TASK_LOCK_KEY.format(user_id=user.id)
doing_value = _generate_value()
commiting_value = _generate_value(stage=lock.COMMITING)
try:
locked = lock.acquire(key, doing_value, timeout=600)
if not locked:
logger.error(f'update_mapping_node_task_locked_failed for user: {user.id}')
raise lock.SomeoneIsDoingThis
with atomic(savepoint=False):
func(*args, **kwargs)
ok = lock.change_lock_state_to_commiting(key, doing_value, commiting_value)
if not ok:
logger.error(f'update_mapping_node_task_timeout for user: {user.id}')
raise lock.Timeout
finally:
lock.release(key, commiting_value, doing_value)
return wrapper
@build_user_mapping_node_lock
def rebuild_user_mapping_nodes_if_need_with_lock(user: User):
tasks = RebuildUserTreeTask.objects.filter(user=user)
if tasks:
tasks.delete()
rebuild_user_mapping_nodes(user)
@build_user_mapping_node_lock
def rebuild_user_mapping_nodes_with_lock(user: User):
rebuild_user_mapping_nodes(user)
@tmp_to_root_org()
def compute_tmp_mapping_node_from_perm(user: User):
node_only_fields = ('id', 'key', 'parent_key', 'assets_amount')
# 查询直接授权节点
nodes = Node.objects.filter(
get_user_resources_q_granted_by_permissions(user)
).distinct().only(*node_only_fields)
granted_key_set = {_node.key for _node in nodes}
def _has_ancestor_granted(node):
"""
判断一个节点是否有授权过的祖先节点
"""
ancestor_keys = set(node.get_ancestor_keys())
return ancestor_keys & granted_key_set
key2leaf_nodes_mapper = {}
# 给授权节点设置 _granted 标识,同时去重
for _node in nodes:
if _has_ancestor_granted(_node):
continue
if _node.key not in key2leaf_nodes_mapper:
set_granted(_node)
key2leaf_nodes_mapper[_node.key] = _node
# 查询授权资产关联的节点设置
def process_direct_granted_assets():
# 查询直接授权资产
asset_ids = Asset.objects.filter(
get_user_resources_q_granted_by_permissions(user)
).distinct().values_list('id', flat=True)
# 查询授权资产关联的节点设置
granted_asset_nodes = Node.objects.filter(
assets__id__in=asset_ids
).distinct().only(*node_only_fields)
# 给资产授权关联的节点设置 _asset_granted 标识,同时去重
for _node in granted_asset_nodes:
if _has_ancestor_granted(_node):
continue
if _node.key not in key2leaf_nodes_mapper:
key2leaf_nodes_mapper[_node.key] = _node
set_asset_granted(key2leaf_nodes_mapper[_node.key])
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
process_direct_granted_assets()
leaf_nodes = key2leaf_nodes_mapper.values()
# 计算所有祖先节点
ancestor_keys = set()
for _node in leaf_nodes:
ancestor_keys.update(_node.get_ancestor_keys())
# 从祖先节点 key 中去掉同时也是叶子节点的 key
ancestor_keys -= key2leaf_nodes_mapper.keys()
# 查出祖先节点
ancestors = Node.objects.filter(key__in=ancestor_keys).only(*node_only_fields)
return [*leaf_nodes, *ancestors]
def create_mapping_nodes(user, nodes, clear=True):
to_create = []
for node in nodes:
_granted = getattr(node, TMP_GRANTED_FIELD, False)
_asset_granted = getattr(node, TMP_ASSET_GRANTED_FIELD, False)
_granted_assets_amount = getattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, 0)
to_create.append(UserGrantedMappingNode(
user=user,
node=node,
key=node.key,
parent_key=node.parent_key,
granted=_granted,
asset_granted=_asset_granted,
assets_amount=_granted_assets_amount,
))
if clear:
UserGrantedMappingNode.objects.filter(user=user).delete()
UserGrantedMappingNode.objects.bulk_create(to_create)
def set_node_granted_assets_amount(user, node):
"""
不依赖`UserGrantedMappingNode`直接查询授权计算资产数量
"""
_granted = getattr(node, TMP_GRANTED_FIELD, False)
if _granted:
assets_amount = node.assets_amount
else:
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
assets_amount = count_direct_granted_node_assets(user, node.key)
else:
assets_amount = count_node_all_granted_assets(user, node.key)
setattr(node, TMP_GRANTED_ASSETS_AMOUNT_FIELD, assets_amount)
def rebuild_user_mapping_nodes(user):
logger.info(f'>>> {dt_formater(now())} start rebuild {user} mapping nodes')
tmp_nodes = compute_tmp_mapping_node_from_perm(user)
for _node in tmp_nodes:
set_node_granted_assets_amount(user, _node)
create_mapping_nodes(user, tmp_nodes)
logger.info(f'>>> {dt_formater(now())} end rebuild {user} mapping nodes')
def get_user_granted_nodes_list_via_mapping_node(user):
"""
这里的 granted nodes, 是整棵树需要的node推算出来的也算
:param user:
:return:
"""
# 获取 `UserGrantedMappingNode` 中对应的 `Node`
nodes = Node.objects.filter(
mapping_nodes__user=user,
).annotate(
**node_annotate_mapping_node
).distinct()
key_to_node_mapper = {}
nodes_descendant_q = Q()
for node in nodes:
if not is_direct_granted_by_annotate(node):
# 未授权的节点资产数量设置为 `UserGrantedMappingNode` 中的数量
node.assets_amount = get_granted_assets_amount(node)
else:
# 直接授权的节点
# 增加查询后代节点的过滤条件
nodes_descendant_q |= Q(key__startswith=f'{node.key}:')
key_to_node_mapper[node.key] = node
if nodes_descendant_q:
descendant_nodes = Node.objects.filter(
nodes_descendant_q
).annotate(
**node_annotate_set_granted
)
for node in descendant_nodes:
key_to_node_mapper[node.key] = node
all_nodes = key_to_node_mapper.values()
return all_nodes
def get_user_granted_all_assets(user, via_mapping_node=True):
asset_perm_ids = get_user_all_assetpermission_ids(user)
if via_mapping_node:
granted_node_keys = UserGrantedMappingNode.objects.filter(
user=user, granted=True,
).values_list('key', flat=True).distinct()
else:
granted_node_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perm_ids
).distinct().values_list('key', flat=True)
granted_node_keys = Node.clean_children_keys(granted_node_keys)
granted_node_q = Q()
for _key in granted_node_keys:
granted_node_q |= Q(nodes__key__startswith=f'{_key}:')
granted_node_q |= Q(nodes__key=_key)
assets__id = get_user_direct_granted_assets(user, asset_perm_ids).values_list('id', flat=True)
q = granted_node_q | Q(id__in=list(assets__id))
return Asset.org_objects.filter(q).distinct()
def get_node_all_granted_assets(user: User, key):
"""
此算法依据 `UserGrantedMappingNode` 的数据查询
1. 查询该节点下的直接授权节点
2. 查询该节点下授权资产关联的节点
"""
assets = Asset.objects.none()
# 查询该节点下的授权节点
granted_mapping_nodes = UserGrantedMappingNode.objects.filter(
user=user, granted=True,
).filter(
Q(key__startswith=f'{key}:') | Q(key=key)
)
# 根据授权节点构建资产查询条件
granted_nodes_qs = []
for _node in granted_mapping_nodes:
granted_nodes_qs.append(Q(nodes__key__startswith=f'{_node.key}:'))
granted_nodes_qs.append(Q(nodes__key=_node.key))
# 查询该节点下的资产授权节点
only_asset_granted_mapping_nodes = UserGrantedMappingNode.objects.filter(
user=user,
asset_granted=True,
granted=False,
).filter(Q(key__startswith=f'{key}:') | Q(key=key))
# 根据资产授权节点构建查询
only_asset_granted_nodes_qs = []
for _node in only_asset_granted_mapping_nodes:
only_asset_granted_nodes_qs.append(Q(nodes__id=_node.node_id))
q = []
if granted_nodes_qs:
q.append(reduce(or_, granted_nodes_qs))
if only_asset_granted_nodes_qs:
only_asset_granted_nodes_q = reduce(or_, only_asset_granted_nodes_qs)
only_asset_granted_nodes_q &= get_user_resources_q_granted_by_permissions(user)
q.append(only_asset_granted_nodes_q)
if q:
assets = Asset.objects.filter(reduce(or_, q)).distinct()
return assets
def get_direct_granted_node_ids(user: User, key):
granted_q = get_user_resources_q_granted_by_permissions(user)
# 先查出该节点下的直接授权节点
granted_nodes = Node.objects.filter(
Q(key__startswith=f'{key}:') | Q(key=key)
).filter(granted_q).distinct().only('id', 'key')
node_ids = set()
# 根据直接授权节点查询他们的子节点
q = Q()
for _node in granted_nodes:
q |= Q(key__startswith=f'{_node.key}:')
node_ids.add(_node.id)
if q:
descendant_ids = Node.objects.filter(q).values_list('id', flat=True).distinct()
node_ids.update(descendant_ids)
return node_ids
def get_node_all_granted_assets_from_perm(user: User, key):
"""
此算法依据 `AssetPermission` 的数据查询
1. 查询该节点下的直接授权节点
2. 查询该节点下授权资产关联的节点
"""
granted_q = get_user_resources_q_granted_by_permissions(user)
# 直接授权资产查询条件
q = (Q(nodes__key__startswith=f'{key}:') | Q(nodes__key=key)) & granted_q
node_ids = get_direct_granted_node_ids(user, key)
q |= Q(nodes__id__in=node_ids)
asset_qs = Asset.objects.filter(q).distinct()
return asset_qs
def get_direct_granted_node_assets_from_perm(user: User, key):
node_ids = get_direct_granted_node_ids(user, key)
asset_qs = Asset.objects.filter(nodes__id__in=node_ids).distinct()
return asset_qs
def count_node_all_granted_assets(user: User, key):
return get_node_all_granted_assets_from_perm(user, key).count()
def count_direct_granted_node_assets(user: User, key):
return get_direct_granted_node_assets_from_perm(user, key).count()
def get_indirect_granted_node_children(user, key=''):
"""
获取用户授权树中未授权节点的子节点
只匹配在 `UserGrantedMappingNode` 中存在的节点
"""
nodes = Node.objects.filter(
mapping_nodes__user=user,
parent_key=key
).annotate(
_granted_assets_amount=F('mapping_nodes__assets_amount'),
_granted=F('mapping_nodes__granted')
).distinct()
# 设置节点授权资产数量
for _node in nodes:
if not is_direct_granted_by_annotate(_node):
_node.assets_amount = get_granted_assets_amount(_node)
return nodes
def get_top_level_granted_nodes(user):
nodes = list(get_indirect_granted_node_children(user, key=''))
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
ungrouped_node = get_ungrouped_node(user)
nodes.insert(0, ungrouped_node)
favorite_node = get_favorite_node(user)
nodes.insert(0, favorite_node)
return nodes
def get_user_all_assetpermission_ids(user: User):
asset_perm_ids = set()
asset_perm_ids.update(
AssetPermission.objects.valid().filter(users=user).distinct().values_list('id', flat=True)
)
asset_perm_ids.update(
AssetPermission.objects.valid().filter(user_groups__users=user).distinct().values_list('id', flat=True)
)
return asset_perm_ids
def get_user_direct_granted_assets(user, asset_perm_ids=None):
if asset_perm_ids is None:
asset_perm_ids = get_user_all_assetpermission_ids(user)
assets = Asset.org_objects.filter(granted_by_permissions__id__in=asset_perm_ids).distinct()
return assets
def count_user_direct_granted_assets(user):
count = get_user_direct_granted_assets(user).values_list('id').count()
return count
def get_ungrouped_node(user):
assets_amount = count_user_direct_granted_assets(user)
return Node(
id=UNGROUPED_NODE_KEY,
key=UNGROUPED_NODE_KEY,
value=_(UNGROUPED_NODE_KEY),
assets_amount=assets_amount
)
def get_favorite_node(user):
assets_amount = FavoriteAsset.get_user_favorite_assets(user).values_list('id').count()
return Node(
id=FAVORITE_NODE_KEY,
key=FAVORITE_NODE_KEY,
value=_(FAVORITE_NODE_KEY),
assets_amount=assets_amount
)
def rebuild_user_tree_if_need(request, user):
"""
升级授权树策略后用户的数据可能还未初始化为防止用户显示没有数据
先检查 MappingNode 如果没有数据同步创建用户授权树
"""
if is_true(request.query_params.get('rebuild_tree')) or \
not UserGrantedMappingNode.objects.filter(user=user).exists():
try:
rebuild_user_mapping_nodes_with_lock(user)
except lock.SomeoneIsDoingThis:
# 您的数据正在初始化,请稍等
raise lock.SomeoneIsDoingThis(detail=_('Please wait while your data is being initialized'))

View File

@ -278,6 +278,7 @@ class PublicSettingApi(generics.RetrieveAPIView):
"SECURITY_VIEW_AUTH_NEED_MFA": settings.SECURITY_VIEW_AUTH_NEED_MFA,
"SECURITY_MFA_VERIFY_TTL": settings.SECURITY_MFA_VERIFY_TTL,
"SECURITY_COMMAND_EXECUTION": settings.SECURITY_COMMAND_EXECUTION,
"LOGIN_TITLE": settings.XPACK_INTERFACE_LOGIN_TITLE,
"LOGO_URLS": settings.LOGO_URLS,
"TICKETS_ENABLED": settings.TICKETS_ENABLED,
"PASSWORD_RULE": {

View File

@ -145,7 +145,7 @@
@fa-twitter: "\f099";
@fa-facebook: "\f09a";
@fa-github: "\f09b";
@fa-unlock: "\f09c";
@fa-release: "\f09c";
@fa-credit-card: "\f09d";
@fa-rss: "\f09e";
@fa-hdd-o: "\f0a0";
@ -283,7 +283,7 @@
@fa-html5: "\f13b";
@fa-css3: "\f13c";
@fa-anchor: "\f13d";
@fa-unlock-alt: "\f13e";
@fa-release-alt: "\f13e";
@fa-bullseye: "\f140";
@fa-ellipsis-h: "\f141";
@fa-ellipsis-v: "\f142";

View File

@ -537,7 +537,7 @@ q-93 26 -192 26t-192 -26q-16 11 -42.5 27t-83.5 38.5t-85 13.5q-45 -113 -8 -204q-7
t-5 -11.5t9 -14t13 -12l7 -5q22 -10 43.5 -38t31.5 -51l10 -23q13 -38 44 -61.5t67 -30t69.5 -7t55.5 3.5l23 4q0 -38 0.5 -88.5t0.5 -54.5q0 -18 -13 -30t-40 -7q-232 77 -378.5 277.5t-146.5 451.5q0 209 103 385.5t279.5 279.5t385.5 103zM291 305q3 7 -7 12
q-10 3 -13 -2q-3 -7 7 -12q9 -6 13 2zM322 271q7 5 -2 16q-10 9 -16 3q-7 -5 2 -16q10 -10 16 -3zM352 226q9 7 0 19q-8 13 -17 6q-9 -5 0 -18t17 -7zM394 184q8 8 -4 19q-12 12 -20 3q-9 -8 4 -19q12 -12 20 -3zM451 159q3 11 -13 16q-15 4 -19 -7t13 -15q15 -6 19 6z
M514 154q0 13 -17 11q-16 0 -16 -11q0 -13 17 -11q16 0 16 11zM572 164q-2 11 -18 9q-16 -3 -14 -15t18 -8t14 14z" />
<glyph glyph-name="unlock" unicode="&#xf09c;" horiz-adv-x="1664"
<glyph glyph-name="release" unicode="&#xf09c;" horiz-adv-x="1664"
d="M1664 960v-256q0 -26 -19 -45t-45 -19h-64q-26 0 -45 19t-19 45v256q0 106 -75 181t-181 75t-181 -75t-75 -181v-192h96q40 0 68 -28t28 -68v-576q0 -40 -28 -68t-68 -28h-960q-40 0 -68 28t-28 68v576q0 40 28 68t68 28h672v192q0 185 131.5 316.5t316.5 131.5
t316.5 -131.5t131.5 -316.5z" />
<glyph glyph-name="credit_card" unicode="&#xf09d;" horiz-adv-x="1920"

Before

Width:  |  Height:  |  Size: 434 KiB

After

Width:  |  Height:  |  Size: 434 KiB

View File

@ -125,7 +125,7 @@ class SessionReplayViewSet(AsyncApiMixin, viewsets.ViewSet):
if serializer.is_valid():
file = serializer.validated_data['file']
name, err = session.save_to_storage(file)
name, err = session.save_replay_to_storage(file)
if not name:
msg = "Failed save replay `{}`: {}".format(session_id, err)
logger.error(msg)
@ -155,6 +155,7 @@ class SessionReplayViewSet(AsyncApiMixin, viewsets.ViewSet):
return data
def is_need_async(self):
return False
if self.action != 'retrieve':
return False
return True

View File

@ -20,9 +20,12 @@ class BaseStorageViewSetMixin:
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
if not instance.can_delete():
if instance.in_defaults():
data = {'msg': _('Deleting the default storage is not allowed')}
return Response(data=data, status=status.HTTP_400_BAD_REQUEST)
if instance.is_using():
data = {'msg': _('Cannot delete storage that is being used')}
return Response(data=data, status=status.HTTP_400_BAD_REQUEST)
return super().destroy(request, *args, **kwargs)

View File

@ -1,5 +1,7 @@
from importlib import import_module
from django.conf import settings
from django.utils.functional import LazyObject
from .command.serializers import SessionCommandSerializer
from ..const import COMMAND_STORAGE_TYPE_SERVER
@ -17,6 +19,13 @@ def get_command_storage():
return storage
def get_server_replay_storage():
from jms_storage import get_object_storage
config = settings.SERVER_REPLAY_STORAGE
storage = get_object_storage(config)
return storage
def get_terminal_command_storages():
from ..models import CommandStorage
storage_list = {}
@ -40,3 +49,15 @@ def get_multi_command_storage():
return storage
class ServerReplayStorage(LazyObject):
def _setup(self):
self._wrapped = get_server_replay_storage()
class ServerCommandStorage(LazyObject):
def _setup(self):
self._wrapped = get_command_storage()
server_command_storage = ServerCommandStorage()
server_replay_storage = ServerReplayStorage()

View File

@ -0,0 +1,3 @@
"""
Replay backend 已移动到 jms_storage 模块中
"""

View File

@ -253,14 +253,18 @@ class Session(OrgModelMixin):
return False
return True
def save_to_storage(self, f):
def save_replay_to_storage(self, f):
local_path = self.get_local_path()
try:
name = default_storage.save(local_path, f)
return name, None
except OSError as e:
return None, e
if settings.SERVER_REPLAY_STORAGE:
from .tasks import upload_session_replay_to_external_storage
upload_session_replay_to_external_storage.delay(str(self.id))
return name, None
@classmethod
def set_sessions_active(cls, sessions_id):
data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id}
@ -401,8 +405,8 @@ class CommandStorage(CommonModelMixin):
storage = jms_storage.get_log_storage(self.config)
return storage.ping()
def can_delete(self):
return not self.in_defaults()
def is_using(self):
return Terminal.objects.filter(command_storage=self.name).exists()
class ReplayStorage(CommonModelMixin):
@ -454,6 +458,5 @@ class ReplayStorage(CommonModelMixin):
src = os.path.join(settings.BASE_DIR, 'common', target)
return storage.is_valid(src, target)
def can_delete(self):
return not self.in_defaults()
def is_using(self):
return Terminal.objects.filter(replay_storage=self.name).exists()

View File

@ -9,11 +9,12 @@ from django.utils import timezone
from django.conf import settings
from django.core.files.storage import default_storage
from ops.celery.decorator import (
register_as_period_task, after_app_ready_start, after_app_shutdown_clean_periodic
)
from .models import Status, Session, Command
from .backends import server_replay_storage
from .utils import find_session_replay_local
CACHE_REFRESH_INTERVAL = 10
@ -68,3 +69,26 @@ def clean_expired_session_period():
# 删除session记录
session.delete()
@shared_task
def upload_session_replay_to_external_storage(session_id):
logger.info(f'Start upload session to external storage: {session_id}')
session = Session.objects.filter(id=session_id).first()
if not session:
logger.error(f'Session db item not found: {session_id}')
return
local_path, foobar = find_session_replay_local(session)
if not local_path:
logger.error(f'Session replay not found, may be upload error: {local_path}')
return
abs_path = default_storage.path(local_path)
remote_path = session.get_rel_replay_path()
ok, err = server_replay_storage.upload(abs_path, remote_path)
if not ok:
logger.error(f'Session replay upload to external error: {err}')
return
try:
default_storage.delete(local_path)
except:
pass
return

View File

@ -2,43 +2,18 @@
#
import os
from django.core.cache import cache
from django.conf import settings
from django.core.files.storage import default_storage
import jms_storage
from assets.models import Asset, SystemUser
from users.models import User
from common.utils import get_logger
from .const import USERS_CACHE_KEY, ASSETS_CACHE_KEY, SYSTEM_USER_CACHE_KEY
from .backends import server_replay_storage
from .models import ReplayStorage
logger = get_logger(__name__)
def get_session_asset_list():
return Asset.objects.values_list('hostname', flat=True)
def get_session_user_list():
return User.objects.exclude(role=User.ROLE.APP).values_list('username', flat=True)
def get_session_system_user_list():
return SystemUser.objects.values_list('username', flat=True)
def get_user_list_from_cache():
return cache.get(USERS_CACHE_KEY)
def get_asset_list_from_cache():
return cache.get(ASSETS_CACHE_KEY)
def get_system_user_list_from_cache():
return cache.get(SYSTEM_USER_CACHE_KEY)
def find_session_replay_local(session):
# 新版本和老版本的文件后缀不同
session_path = session.get_rel_replay_path() # 存在外部存储上的路径
@ -62,6 +37,8 @@ def download_session_replay(session):
for storage in replay_storages
if not storage.in_defaults()
}
if settings.SERVER_REPLAY_STORAGE:
configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE
if not configs:
msg = "Not found replay file, and not remote storage set"
return None, msg

View File

@ -63,13 +63,14 @@ class RequestAssetPermTicketViewSet(JMSModelViewSet):
meta = instance.meta
ips = ', '.join(meta.get('ips', []))
confirmed_assets = ', '.join(meta.get('confirmed_assets', []))
confirmed_system_users = ', '.join(meta.get('confirmed_system_users', []))
return textwrap.dedent(f'''\
{_('IP group')}: {ips}
{_('Hostname')}: {meta.get('hostname', '')}
{_('System user')}: {meta.get('system_user', '')}
{_('Confirmed assets')}: {confirmed_assets}
{_('Confirmed system user')}: {meta.get('confirmed_system_user', '')}
{_('Confirmed system users')}: {confirmed_system_users}
''')
@action(detail=True, methods=[POST], permission_classes=[IsAssignee, IsValidUser])
@ -95,15 +96,15 @@ class RequestAssetPermTicketViewSet(JMSModelViewSet):
if len(assets) != len(confirmed_assets):
raise ConfirmedAssetsChanged(detail=_('Confirmed assets changed'))
confirmed_system_user = meta.get('confirmed_system_user')
if not confirmed_system_user:
raise NotHaveConfirmedSystemUser(detail=_('Confirm system-user first'))
confirmed_system_users = meta.get('confirmed_system_users', [])
if not confirmed_system_users:
raise NotHaveConfirmedSystemUser(detail=_('Confirm system-users first'))
system_user = get_object_or_none(SystemUser, id=confirmed_system_user)
if system_user is None:
raise ConfirmedSystemUserChanged(detail=_('Confirmed system-user changed'))
system_users = SystemUser.objects.filter(id__in=confirmed_system_users)
if system_users is None:
raise ConfirmedSystemUserChanged(detail=_('Confirmed system-users changed'))
self._create_asset_permission(instance, assets, system_user)
self._create_asset_permission(instance, assets, system_users)
return Response({'detail': _('Succeed')})
@action(detail=True, methods=[POST], permission_classes=[IsAssignee | IsObjectOwner])
@ -113,7 +114,7 @@ class RequestAssetPermTicketViewSet(JMSModelViewSet):
instance.save()
return Response({'detail': _('Succeed')})
def _create_asset_permission(self, instance: Ticket, assets, system_user):
def _create_asset_permission(self, instance: Ticket, assets, system_users):
meta = instance.meta
request = self.request
actions = meta.get('actions', Action.CONNECT)
@ -135,7 +136,7 @@ class RequestAssetPermTicketViewSet(JMSModelViewSet):
request.user,
self._get_extra_comment(instance))
ap = AssetPermission.objects.create(**ap_kwargs)
ap.system_users.add(system_user)
ap.system_users.add(*system_users)
ap.assets.add(*assets)
ap.users.add(instance.user)

View File

@ -0,0 +1,31 @@
# Generated by BaiJiangjie 2020-09-29 18:31
from django.db import migrations
def migrate_ticket_meta_confirmed_system_user_to_confirmed_system_users(apps, schema_editor):
ticket_model = apps.get_model("tickets", "Ticket")
tickets = ticket_model.origin_objects.all()
for ticket in tickets:
meta = ticket.meta
confirmed_system_user = meta.get('confirmed_system_user')
if confirmed_system_user:
confirmed_system_users = [confirmed_system_user]
else:
confirmed_system_users = []
meta.update({
'confirmed_system_users': confirmed_system_users
})
ticket.meta = meta
ticket.save()
class Migration(migrations.Migration):
dependencies = [
('tickets', '0004_ticket_comment'),
]
operations = [
migrations.RunPython(migrate_ticket_meta_confirmed_system_user_to_confirmed_system_users)
]

View File

@ -33,19 +33,20 @@ class RequestAssetPermTicketSerializer(serializers.ModelSerializer):
source='meta.confirmed_assets',
default=list, required=False,
label=_('Confirmed assets'))
confirmed_system_user = serializers.UUIDField(source='meta.confirmed_system_user',
default='', required=False,
confirmed_system_users = serializers.ListField(child=serializers.UUIDField(),
source='meta.confirmed_system_users',
default=list, required=False,
label=_('Confirmed system user'))
assets_waitlist_url = serializers.SerializerMethodField()
system_user_waitlist_url = serializers.SerializerMethodField()
system_users_waitlist_url = serializers.SerializerMethodField()
class Meta:
model = Ticket
mini_fields = ['id', 'title']
small_fields = [
'status', 'action', 'date_created', 'date_updated', 'system_user_waitlist_url',
'status', 'action', 'date_created', 'date_updated', 'system_users_waitlist_url',
'type', 'type_display', 'action_display', 'ips', 'confirmed_assets',
'date_start', 'date_expired', 'confirmed_system_user', 'hostname',
'date_start', 'date_expired', 'confirmed_system_users', 'hostname',
'assets_waitlist_url', 'system_user', 'org_id', 'actions', 'comment'
]
m2m_fields = [
@ -96,7 +97,7 @@ class RequestAssetPermTicketSerializer(serializers.ModelSerializer):
raise serializers.ValidationError(_('Field `assignees` must be organization admin or superuser'))
return attrs
def get_system_user_waitlist_url(self, instance: Ticket):
def get_system_users_waitlist_url(self, instance: Ticket):
if not self._is_assignee(instance):
return None
return reverse('api-assets:system-user-list')
@ -190,16 +191,14 @@ class RequestAssetPermTicketSerializer(serializers.ModelSerializer):
meta['date_expired'] = dt_formater(date_expired)
# UUID 的转换
confirmed_system_user = meta.get('confirmed_system_user')
if confirmed_system_user:
meta['confirmed_system_user'] = str(confirmed_system_user)
confirmed_system_users = meta.get('confirmed_system_users')
if confirmed_system_users:
meta['confirmed_system_users'] = [str(system_user) for system_user in confirmed_system_users]
confirmed_assets = meta.get('confirmed_assets')
if confirmed_assets:
new_confirmed_assets = []
for asset in confirmed_assets:
new_confirmed_assets.append(str(asset))
meta['confirmed_assets'] = new_confirmed_assets
meta['confirmed_assets'] = [str(asset) for asset in confirmed_assets]
with tmp_to_root_org():
return super().save(**kwargs)
@ -220,7 +219,7 @@ class RequestAssetPermTicketSerializer(serializers.ModelSerializer):
def _pop_confirmed_fields(self):
meta = self.validated_data['meta']
meta.pop('confirmed_assets', None)
meta.pop('confirmed_system_user', None)
meta.pop('confirmed_system_users', None)
def _is_assignee(self, obj: Ticket):
user = self.context['request'].user

View File

@ -393,16 +393,20 @@ class TokenMixin:
@classmethod
def validate_reset_password_token(cls, token):
if not token:
return None
key = cls.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)
value = cache.get(key)
if not value:
return None
try:
key = cls.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)
value = cache.get(key)
user_id = value.get('id', '')
email = value.get('email', '')
user = cls.objects.get(id=user_id, email=email)
return user
except (AttributeError, cls.DoesNotExist) as e:
logger.error(e, exc_info=True)
user = None
return user
return None
def set_cache(self, token):
key = self.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)

View File

@ -9,6 +9,7 @@ from django_cas_ng.signals import cas_user_authenticated
from jms_oidc_rp.signals import openid_create_or_update_user
from perms.tasks import create_rebuild_user_tree_task
from common.utils import get_logger
from .signals import post_user_create
from .models import User
@ -27,14 +28,12 @@ def on_user_create(sender, user=None, **kwargs):
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(sender, instance=None, action='', **kwargs):
"""
资产节点发生变化时刷新节点
"""
def on_user_groups_change(instance, action, reverse, pk_set, **kwargs):
if action.startswith('post'):
logger.debug("User group member change signal recv: {}".format(instance))
from perms.utils import AssetPermissionUtil
AssetPermissionUtil.expire_all_user_tree_cache()
if reverse:
create_rebuild_user_tree_task(pk_set)
else:
create_rebuild_user_tree_task([instance.id])
@receiver(cas_user_authenticated)

View File

@ -13,7 +13,7 @@ from django.core.cache import cache
from datetime import datetime
from common.tasks import send_mail_async
from common.utils import reverse, get_object_or_none
from common.utils import reverse, get_object_or_none, get_request_ip_or_data, get_request_user_agent
from .models import User
@ -112,6 +112,48 @@ def send_reset_password_mail(user):
send_mail_async.delay(subject, message, recipient_list, html_message=message)
def send_reset_password_success_mail(request, user):
subject = _('Reset password success')
recipient_list = [user.email]
message = _("""
Hi %(name)s:
<br>
<br>
Your JumpServer password has just been successfully updated.
<br>
<br>
If the password update was not initiated by you, your account may have security issues.
It is recommended that you log on to the JumpServer immediately and change your password.
<br>
<br>
If you have any questions, you can contact the administrator.
<br>
<br>
---
<br>
<br>
IP Address: %(ip_address)s
<br>
<br>
Browser: %(browser)s
<br>
""") % {
'name': user.name,
'ip_address': get_request_ip_or_data(request),
'browser': get_request_user_agent(request),
}
if settings.DEBUG:
logger.debug(message)
send_mail_async.delay(subject, message, recipient_list, html_message=message)
def send_password_expiration_reminder_mail(user):
subject = _('Security notice')
recipient_list = [user.email]

View File

@ -16,7 +16,8 @@ from common.utils import get_object_or_none
from common.permissions import PermissionsMixin, IsValidUser
from ...models import User
from ...utils import (
send_reset_password_mail, get_password_check_rules, check_password_rules
send_reset_password_mail, get_password_check_rules, check_password_rules,
send_reset_password_success_mail
)
from ... import forms
@ -126,6 +127,7 @@ class UserResetPasswordView(FormView):
user.reset_password(password)
User.expired_reset_password_token(token)
send_reset_password_success_mail(self.request, user)
return redirect('authentication:reset-password-success')

44
jms
View File

@ -155,20 +155,20 @@ def is_running(s, unlink=True):
def parse_service(s):
all_services = [
'gunicorn', 'celery_ansible', 'celery_default',
'beat', 'flower', 'daphne',
]
web_services = ['gunicorn', 'flower', 'daphne']
celery_services = ["celery_ansible", "celery_default", "celery_node_tree", "check_asset_perm_expired"]
task_services = celery_services + ['beat']
all_services = web_services + task_services
if s == 'all':
return all_services
elif s == "web":
return ['gunicorn', 'flower', 'daphne']
return web_services
elif s == 'ws':
return ['daphne']
elif s == "task":
return ["celery_ansible", "celery_default", "beat"]
return task_services
elif s == "celery":
return ["celery_ansible", "celery_default"]
return celery_services
elif "," in s:
services = set()
for i in s.split(','):
@ -220,6 +220,16 @@ def get_start_celery_default_kwargs():
return get_start_worker_kwargs('celery', 4)
def get_start_celery_node_tree_kwargs():
print("\n- Start Celery as Distributed Task Queue: NodeTree")
return get_start_worker_kwargs('node_tree', 2)
def get_start_celery_check_asset_perm_expired_kwargs():
print("\n- Start Celery as Distributed Task Queue: CheckAseetPermissionExpired")
return get_start_worker_kwargs('check_asset_perm_expired', 1)
def get_start_worker_kwargs(queue, num):
# Todo: Must set this environment, otherwise not no ansible result return
os.environ.setdefault('PYTHONOPTIMIZE', '1')
@ -261,19 +271,11 @@ def get_start_flower_kwargs():
def get_start_beat_kwargs():
print("\n- Start Beat as Periodic Task Scheduler")
os.environ.setdefault('PYTHONOPTIMIZE', '1')
if os.getuid() == 0:
os.environ.setdefault('C_FORCE_ROOT', '1')
scheduler = "django_celery_beat.schedulers:DatabaseScheduler"
utils_dir = os.path.join(BASE_DIR, 'utils')
cmd = [
'celery', 'beat',
'-A', 'ops',
'-l', 'INFO',
'--scheduler', scheduler,
'--max-interval', '60'
sys.executable, 'start_celery_beat.py',
]
return {"cmd": cmd, 'cwd': APPS_DIR}
return {"cmd": cmd, 'cwd': utils_dir}
processes = {}
@ -289,7 +291,7 @@ def watch_services():
for s, p in processes.items():
print("{} Check service status: {} -> ".format(now, s), end='')
try:
p.wait(timeout=1)
p.wait(timeout=1) # 不wait子进程可能无法回收
except subprocess.TimeoutExpired:
pass
ok = is_running(s)
@ -363,6 +365,8 @@ def start_service(s):
"gunicorn": get_start_gunicorn_kwargs,
"celery_ansible": get_start_celery_ansible_kwargs,
"celery_default": get_start_celery_default_kwargs,
"celery_node_tree": get_start_celery_node_tree_kwargs,
"check_asset_perm_expired": get_start_celery_check_asset_perm_expired_kwargs,
"beat": get_start_beat_kwargs,
"flower": get_start_flower_kwargs,
"daphne": get_start_daphne_kwargs,
@ -388,7 +392,7 @@ def start_service(s):
def start_services_and_watch(s):
logging.info(time.ctime())
logging.info('Jumpserver version {}, more see https://www.jumpserver.org'.format(
logging.info('JumpServer version {}, more see https://www.jumpserver.org'.format(
__version__)
)

View File

@ -0,0 +1,34 @@
#!/usr/bin/env python
#
import os
import sys
import subprocess
import redis_lock
from redis import Redis
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
APPS_DIR = os.path.join(BASE_DIR, 'apps')
sys.path.insert(0, BASE_DIR)
from apps.jumpserver.const import CONFIG
os.environ.setdefault('PYTHONOPTIMIZE', '1')
if os.getuid() == 0:
os.environ.setdefault('C_FORCE_ROOT', '1')
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
scheduler = "django_celery_beat.schedulers:DatabaseScheduler"
cmd = [
'celery', 'beat',
'-A', 'ops',
'-l', 'INFO',
'--scheduler', scheduler,
'--max-interval', '60'
]
with redis_lock.Lock(redis, name="beat-distribute-start-lock", expire=60, auto_renewal=True):
print("Get beat lock start to run it")
code = subprocess.call(cmd, cwd=APPS_DIR)
sys.exit(code)