mirror of https://github.com/jumpserver/jumpserver
perf: 修改过滤的 q
parent
7c850a8a1e
commit
4e5ab5a605
|
@ -1,7 +1,7 @@
|
|||
from common.api import JMSBulkModelViewSet
|
||||
from ..models import LoginACL
|
||||
from .. import serializers
|
||||
from ..filters import LoginAclFilter
|
||||
from ..models import LoginACL
|
||||
|
||||
__all__ = ['LoginACLViewSet']
|
||||
|
||||
|
@ -11,4 +11,3 @@ class LoginACLViewSet(JMSBulkModelViewSet):
|
|||
filterset_class = LoginAclFilter
|
||||
search_fields = ('name',)
|
||||
serializer_class = serializers.LoginACLSerializer
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from rest_framework.generics import CreateAPIView
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.db.fields import JSONManyToManyField
|
||||
from common.utils import reverse, lazyproperty
|
||||
from orgs.utils import tmp_to_org
|
||||
from .. import serializers
|
||||
|
@ -36,9 +35,9 @@ class LoginAssetCheckAPI(CreateAPIView):
|
|||
|
||||
# 用户满足的 acls
|
||||
queryset = LoginAssetACL.objects.all()
|
||||
q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'users', user)
|
||||
q = LoginAssetACL.users.get_filter_q(LoginAssetACL, 'users', user)
|
||||
queryset = queryset.filter(q)
|
||||
q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'assets', asset)
|
||||
q = LoginAssetACL.assets.get_filter_q(LoginAssetACL, 'assets', asset)
|
||||
queryset = queryset.filter(q)
|
||||
account_username = self.serializer.validated_data.get('account_username')
|
||||
queryset = queryset.filter(accounts__contains=account_username)
|
||||
|
|
|
@ -3,19 +3,20 @@
|
|||
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.db.models import Q, Manager
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework.utils.encoders import JSONEncoder
|
||||
|
||||
from common.local import add_encrypted_field_set
|
||||
from common.utils import signer, crypto
|
||||
from common.utils import signer, crypto, contains_ip
|
||||
from .validators import PortRangeValidator
|
||||
|
||||
__all__ = [
|
||||
|
@ -321,58 +322,82 @@ class RelatedManager:
|
|||
continue
|
||||
return q
|
||||
|
||||
def _get_filter_attrs_q(self, value, to_model):
|
||||
filters = Q()
|
||||
# 特殊情况有这几种,
|
||||
# 1. 像 资产中的 type 和 category,集成自 Platform。所以不能直接查询
|
||||
# 2. 像 资产中的 nodes,不是简单的 m2m,是树 的关系
|
||||
# 3. 像 用户中的 orgs 也不是简单的 m2m,也是计算出来的
|
||||
# get_filter_{}_attr_q 处理复杂的
|
||||
custom_attr_filter = getattr(to_model, "get_json_filter_attr_q", None)
|
||||
for attr in value["attrs"]:
|
||||
if not isinstance(attr, dict):
|
||||
continue
|
||||
|
||||
name = attr.get('name')
|
||||
val = attr.get('value')
|
||||
match = attr.get('match', 'exact')
|
||||
if name is None or val is None:
|
||||
continue
|
||||
|
||||
print("Has custom filter: {}".format(custom_attr_filter))
|
||||
if custom_attr_filter:
|
||||
custom_filter_q = custom_attr_filter(name, val, match)
|
||||
print("Custom filter: {}".format(custom_filter_q))
|
||||
if custom_filter_q:
|
||||
filters &= custom_filter_q
|
||||
continue
|
||||
|
||||
if match == 'ip_in':
|
||||
q = self.get_ip_in_q(name, val)
|
||||
elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"):
|
||||
lookup = "{}__{}".format(name, match)
|
||||
q = Q(**{lookup: val})
|
||||
elif match == "not":
|
||||
q = ~Q(**{name: val})
|
||||
elif match == "m2m":
|
||||
if not isinstance(val, list):
|
||||
val = [val]
|
||||
q = Q(**{"{}__in".format(name): val})
|
||||
elif match == "in" and isinstance(val, list):
|
||||
if '*' not in val:
|
||||
lookup = "{}__in".format(name)
|
||||
q = Q(**{lookup: val})
|
||||
else:
|
||||
q = Q()
|
||||
else:
|
||||
if val == '*':
|
||||
q = Q()
|
||||
else:
|
||||
q = Q(**{name: val})
|
||||
|
||||
filters &= q
|
||||
return filters
|
||||
|
||||
def _get_queryset(self):
|
||||
model = apps.get_model(self.field.to)
|
||||
to_model = apps.get_model(self.field.to)
|
||||
value = self.value
|
||||
if hasattr(to_model, "get_queryset"):
|
||||
queryset = to_model.get_queryset()
|
||||
else:
|
||||
queryset = to_model.objects.all()
|
||||
|
||||
if not value or not isinstance(value, dict):
|
||||
return model.objects.none()
|
||||
return queryset.none()
|
||||
|
||||
if value["type"] == "all":
|
||||
return model.objects.all()
|
||||
return queryset
|
||||
elif value["type"] == "ids" and isinstance(value.get("ids"), list):
|
||||
return model.objects.filter(id__in=value["ids"])
|
||||
return queryset.filter(id__in=value["ids"])
|
||||
elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
|
||||
filters = Q()
|
||||
excludes = Q()
|
||||
for attr in value["attrs"]:
|
||||
if not isinstance(attr, dict):
|
||||
continue
|
||||
|
||||
name = attr.get('name')
|
||||
val = attr.get('value')
|
||||
match = attr.get('match', 'exact')
|
||||
rel = attr.get('rel', 'and')
|
||||
if name is None or val is None:
|
||||
continue
|
||||
|
||||
if match == 'ip_in':
|
||||
q = self.get_ip_in_q(name, val)
|
||||
elif match in ("exact", "contains", "startswith", "endswith", "regex"):
|
||||
lookup = "{}__{}".format(name, match)
|
||||
q = Q(**{lookup: val})
|
||||
elif match == "not":
|
||||
q = ~Q(**{name: val})
|
||||
elif match == "in" and isinstance(val, list):
|
||||
if '*' not in val:
|
||||
lookup = "{}__in".format(name)
|
||||
q = Q(**{lookup: val})
|
||||
else:
|
||||
q = Q()
|
||||
else:
|
||||
if val == '*':
|
||||
q = Q()
|
||||
else:
|
||||
q = Q(**{name: val})
|
||||
|
||||
if rel == 'or':
|
||||
filters |= q
|
||||
elif rel == 'not':
|
||||
excludes |= q
|
||||
else:
|
||||
filters &= q
|
||||
return model.objects.filter(filters).exclude(excludes)
|
||||
q = self._get_filter_attrs_q(value, to_model)
|
||||
return queryset.filter(q)
|
||||
else:
|
||||
return model.objects.none()
|
||||
return queryset.none()
|
||||
|
||||
def get_attr_q(self):
|
||||
q = self._get_filter_attrs_q(self.value)
|
||||
return q
|
||||
|
||||
def all(self):
|
||||
return self._get_queryset()
|
||||
|
@ -415,40 +440,68 @@ class JSONManyToManyDescriptor:
|
|||
value = value.value
|
||||
manager.set(value)
|
||||
|
||||
def test_is(self):
|
||||
print("Self.field is", self.field)
|
||||
print("Self.field to", self.field.to)
|
||||
print("Self.field model", self.field.model)
|
||||
print("Self.field column", self.field.column)
|
||||
print("Self.field to", self.field.__dict__)
|
||||
def is_match(self, obj, attr_rules):
|
||||
# m2m 的情况
|
||||
# 自定义的情况:比如 nodes, category
|
||||
res = True
|
||||
to_model = apps.get_model(self.field.to)
|
||||
src_model = self.field.model
|
||||
field_name = self.field.name
|
||||
custom_attr_filter = getattr(src_model, "get_filter_{}_attr_q".format(field_name), None)
|
||||
|
||||
@staticmethod
|
||||
def attr_to_regex(attr):
|
||||
"""将属性规则转换为正则表达式"""
|
||||
name, value, match = attr['name'], attr['value'], attr['match']
|
||||
if match == 'contains':
|
||||
return r'.*{}.*'.format(escape_regex(value))
|
||||
elif match == 'startswith':
|
||||
return r'^{}.*'.format(escape_regex(value))
|
||||
elif match == 'endswith':
|
||||
return r'.*{}$'.format(escape_regex(value))
|
||||
elif match == 'regex':
|
||||
return value
|
||||
elif match == 'not':
|
||||
return r'^(?!^{}$)'.format(escape_regex(value))
|
||||
elif match == 'in':
|
||||
values = '|'.join(map(escape_regex, value))
|
||||
return r'^(?:{})$'.format(values)
|
||||
else:
|
||||
return r'^{}$'.format(escape_regex(value))
|
||||
|
||||
def is_match(self, attr_dict, attr_rules):
|
||||
custom_q = Q()
|
||||
for rule in attr_rules:
|
||||
value = attr_dict.get(rule['name'], '')
|
||||
regex = self.attr_to_regex(rule)
|
||||
if not re.match(regex, value):
|
||||
return False
|
||||
return True
|
||||
value = getattr(obj, rule['name'], '')
|
||||
rule_value = rule.get('value', '')
|
||||
rule_match = rule.get('match', 'exact')
|
||||
|
||||
if custom_attr_filter:
|
||||
q = custom_attr_filter(rule['name'], rule_value, rule_match)
|
||||
if q:
|
||||
custom_q &= q
|
||||
continue
|
||||
|
||||
if rule_match == 'in':
|
||||
res &= value in rule_value
|
||||
elif rule_match == 'exact':
|
||||
res &= value == rule_value
|
||||
elif rule_match == 'contains':
|
||||
res &= rule_value in value
|
||||
elif rule_match == 'startswith':
|
||||
res &= str(value).startswith(str(rule_value))
|
||||
elif rule_match == 'endswith':
|
||||
res &= str(value).endswith(str(rule_value))
|
||||
elif rule_match == 'regex':
|
||||
res &= re.match(rule_value, value)
|
||||
elif rule_match == 'not':
|
||||
res &= value != rule_value
|
||||
elif rule['match'] == 'gte':
|
||||
res &= value >= rule_value
|
||||
elif rule['match'] == 'lte':
|
||||
res &= value <= rule_value
|
||||
elif rule['match'] == 'gt':
|
||||
res &= value > rule_value
|
||||
elif rule['match'] == 'lt':
|
||||
res &= value < rule_value
|
||||
elif rule['match'] == 'ip_in':
|
||||
if isinstance(rule_value, str):
|
||||
rule_value = [rule_value]
|
||||
res &= contains_ip(value, rule_value)
|
||||
elif rule['match'] == 'm2m':
|
||||
if isinstance(value, Manager):
|
||||
value = value.values_list('id', flat=True)
|
||||
value = set(map(str, value))
|
||||
rule_value = set(map(str, rule_value))
|
||||
res &= rule_value.issubset(value)
|
||||
else:
|
||||
logging.error("unknown match: {}".format(rule['match']))
|
||||
res &= False
|
||||
|
||||
if not res:
|
||||
return res
|
||||
if custom_q:
|
||||
res &= to_model.objects.filter(custom_q).filter(id=obj.id).exists()
|
||||
return res
|
||||
|
||||
def get_filter_q(self, instance):
|
||||
model_cls = self.field.model
|
||||
|
@ -457,18 +510,12 @@ class JSONManyToManyDescriptor:
|
|||
queryset_id_attrs = model_cls.objects \
|
||||
.filter(**{'{}__type'.format(field_name): 'attrs'}) \
|
||||
.values_list('id', '{}__attrs'.format(field_name))
|
||||
instance_attr = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')}
|
||||
ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance_attr, attr_rules)]
|
||||
ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance, attr_rules)]
|
||||
if ids:
|
||||
q |= Q(id__in=ids)
|
||||
return q
|
||||
|
||||
|
||||
def escape_regex(s):
|
||||
"""转义字符串中的正则表达式特殊字符"""
|
||||
return re.sub('[.*+?^${}()|[\\]]', r'\\\g<0>', s)
|
||||
|
||||
|
||||
class JSONManyToManyField(models.JSONField):
|
||||
def __init__(self, to, *args, **kwargs):
|
||||
self.to = to
|
||||
|
@ -490,7 +537,7 @@ class JSONManyToManyField(models.JSONField):
|
|||
e = ValueError(_(
|
||||
"Invalid JSON data for JSONManyToManyField, should be like "
|
||||
"{'type': 'all'} or {'type': 'ids', 'ids': []} "
|
||||
"or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': 'value', 'rel': 'and|or|not'}}"
|
||||
"or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': '1.1.1.1'}}"
|
||||
))
|
||||
if not isinstance(val, dict):
|
||||
raise e
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import ipaddress
|
||||
import socket
|
||||
from ipaddress import ip_network, ip_address
|
||||
|
||||
|
@ -75,6 +76,23 @@ def contains_ip(ip, ip_group):
|
|||
return False
|
||||
|
||||
|
||||
def is_ip(self, ip, rule_value):
|
||||
if rule_value == '*':
|
||||
return True
|
||||
elif '/' in rule_value:
|
||||
network = ipaddress.ip_network(rule_value)
|
||||
return ip in network.hosts()
|
||||
elif '-' in rule_value:
|
||||
start_ip, end_ip = rule_value.split('-')
|
||||
start_ip = ipaddress.ip_address(start_ip)
|
||||
end_ip = ipaddress.ip_address(end_ip)
|
||||
return start_ip <= ip <= end_ip
|
||||
elif len(rule_value.split('.')) == 4:
|
||||
return ip == rule_value
|
||||
else:
|
||||
return ip.startswith(rule_value)
|
||||
|
||||
|
||||
def get_ip_city(ip):
|
||||
if not ip or not isinstance(ip, str):
|
||||
return _("Invalid address")
|
||||
|
|
|
@ -668,7 +668,33 @@ class MFAMixin:
|
|||
return backend
|
||||
|
||||
|
||||
class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
||||
class JSONFilterMixin:
|
||||
"""
|
||||
users = JSONManyToManyField('users.User', blank=True, null=True)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_json_filter_attr_q(name, value, match):
|
||||
from rbac.models import RoleBinding
|
||||
from orgs.utils import current_org
|
||||
|
||||
if name == 'system_roles':
|
||||
user_id = RoleBinding.objects \
|
||||
.filter(role__in=value, scope='system') \
|
||||
.values_list('user_id', flat=True)
|
||||
return models.Q(id__in=user_id)
|
||||
elif name == 'org_roles':
|
||||
kwargs = dict(role__in=value, scope='org')
|
||||
if not current_org.is_root():
|
||||
kwargs['org_id'] = current_org.id
|
||||
|
||||
user_id = RoleBinding.objects.filter(**kwargs) \
|
||||
.values_list('user_id', flat=True)
|
||||
return models.Q(id__in=user_id)
|
||||
return None
|
||||
|
||||
|
||||
class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, JSONFilterMixin, AbstractUser):
|
||||
class Source(models.TextChoices):
|
||||
local = 'local', _('Local')
|
||||
ldap = 'ldap', 'LDAP/AD'
|
||||
|
|
Loading…
Reference in New Issue