perf: 修改过滤的 q

pull/10327/head
ibuler 2023-05-18 13:14:32 +08:00
parent 7c850a8a1e
commit 4e5ab5a605
5 changed files with 181 additions and 92 deletions

View File

@ -1,7 +1,7 @@
from common.api import JMSBulkModelViewSet from common.api import JMSBulkModelViewSet
from ..models import LoginACL
from .. import serializers from .. import serializers
from ..filters import LoginAclFilter from ..filters import LoginAclFilter
from ..models import LoginACL
__all__ = ['LoginACLViewSet'] __all__ = ['LoginACLViewSet']
@ -11,4 +11,3 @@ class LoginACLViewSet(JMSBulkModelViewSet):
filterset_class = LoginAclFilter filterset_class = LoginAclFilter
search_fields = ('name',) search_fields = ('name',)
serializer_class = serializers.LoginACLSerializer serializer_class = serializers.LoginACLSerializer

View File

@ -1,7 +1,6 @@
from rest_framework.generics import CreateAPIView from rest_framework.generics import CreateAPIView
from rest_framework.response import Response from rest_framework.response import Response
from common.db.fields import JSONManyToManyField
from common.utils import reverse, lazyproperty from common.utils import reverse, lazyproperty
from orgs.utils import tmp_to_org from orgs.utils import tmp_to_org
from .. import serializers from .. import serializers
@ -36,9 +35,9 @@ class LoginAssetCheckAPI(CreateAPIView):
# 用户满足的 acls # 用户满足的 acls
queryset = LoginAssetACL.objects.all() queryset = LoginAssetACL.objects.all()
q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'users', user) q = LoginAssetACL.users.get_filter_q(LoginAssetACL, 'users', user)
queryset = queryset.filter(q) queryset = queryset.filter(q)
q = JSONManyToManyField.get_filter_q(LoginAssetACL, 'assets', asset) q = LoginAssetACL.assets.get_filter_q(LoginAssetACL, 'assets', asset)
queryset = queryset.filter(q) queryset = queryset.filter(q)
account_username = self.serializer.validated_data.get('account_username') account_username = self.serializer.validated_data.get('account_username')
queryset = queryset.filter(accounts__contains=account_username) queryset = queryset.filter(accounts__contains=account_username)

View File

@ -3,19 +3,20 @@
import ipaddress import ipaddress
import json import json
import logging
import re import re
from django.apps import apps from django.apps import apps
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import MinValueValidator, MaxValueValidator from django.core.validators import MinValueValidator, MaxValueValidator
from django.db import models from django.db import models
from django.db.models import Q from django.db.models import Q, Manager
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.utils.encoders import JSONEncoder from rest_framework.utils.encoders import JSONEncoder
from common.local import add_encrypted_field_set from common.local import add_encrypted_field_set
from common.utils import signer, crypto from common.utils import signer, crypto, contains_ip
from .validators import PortRangeValidator from .validators import PortRangeValidator
__all__ = [ __all__ = [
@ -321,58 +322,82 @@ class RelatedManager:
continue continue
return q return q
def _get_filter_attrs_q(self, value, to_model):
filters = Q()
# 特殊情况有这几种,
# 1. 像 资产中的 type 和 category集成自 Platform。所以不能直接查询
# 2. 像 资产中的 nodes不是简单的 m2m是树 的关系
# 3. 像 用户中的 orgs 也不是简单的 m2m也是计算出来的
# get_filter_{}_attr_q 处理复杂的
custom_attr_filter = getattr(to_model, "get_json_filter_attr_q", None)
for attr in value["attrs"]:
if not isinstance(attr, dict):
continue
name = attr.get('name')
val = attr.get('value')
match = attr.get('match', 'exact')
if name is None or val is None:
continue
print("Has custom filter: {}".format(custom_attr_filter))
if custom_attr_filter:
custom_filter_q = custom_attr_filter(name, val, match)
print("Custom filter: {}".format(custom_filter_q))
if custom_filter_q:
filters &= custom_filter_q
continue
if match == 'ip_in':
q = self.get_ip_in_q(name, val)
elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"):
lookup = "{}__{}".format(name, match)
q = Q(**{lookup: val})
elif match == "not":
q = ~Q(**{name: val})
elif match == "m2m":
if not isinstance(val, list):
val = [val]
q = Q(**{"{}__in".format(name): val})
elif match == "in" and isinstance(val, list):
if '*' not in val:
lookup = "{}__in".format(name)
q = Q(**{lookup: val})
else:
q = Q()
else:
if val == '*':
q = Q()
else:
q = Q(**{name: val})
filters &= q
return filters
def _get_queryset(self): def _get_queryset(self):
model = apps.get_model(self.field.to) to_model = apps.get_model(self.field.to)
value = self.value value = self.value
if hasattr(to_model, "get_queryset"):
queryset = to_model.get_queryset()
else:
queryset = to_model.objects.all()
if not value or not isinstance(value, dict): if not value or not isinstance(value, dict):
return model.objects.none() return queryset.none()
if value["type"] == "all": if value["type"] == "all":
return model.objects.all() return queryset
elif value["type"] == "ids" and isinstance(value.get("ids"), list): elif value["type"] == "ids" and isinstance(value.get("ids"), list):
return model.objects.filter(id__in=value["ids"]) return queryset.filter(id__in=value["ids"])
elif value["type"] == "attrs" and isinstance(value.get("attrs"), list): elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
filters = Q() q = self._get_filter_attrs_q(value, to_model)
excludes = Q() return queryset.filter(q)
for attr in value["attrs"]:
if not isinstance(attr, dict):
continue
name = attr.get('name')
val = attr.get('value')
match = attr.get('match', 'exact')
rel = attr.get('rel', 'and')
if name is None or val is None:
continue
if match == 'ip_in':
q = self.get_ip_in_q(name, val)
elif match in ("exact", "contains", "startswith", "endswith", "regex"):
lookup = "{}__{}".format(name, match)
q = Q(**{lookup: val})
elif match == "not":
q = ~Q(**{name: val})
elif match == "in" and isinstance(val, list):
if '*' not in val:
lookup = "{}__in".format(name)
q = Q(**{lookup: val})
else:
q = Q()
else:
if val == '*':
q = Q()
else:
q = Q(**{name: val})
if rel == 'or':
filters |= q
elif rel == 'not':
excludes |= q
else:
filters &= q
return model.objects.filter(filters).exclude(excludes)
else: else:
return model.objects.none() return queryset.none()
def get_attr_q(self):
q = self._get_filter_attrs_q(self.value)
return q
def all(self): def all(self):
return self._get_queryset() return self._get_queryset()
@ -415,40 +440,68 @@ class JSONManyToManyDescriptor:
value = value.value value = value.value
manager.set(value) manager.set(value)
def test_is(self): def is_match(self, obj, attr_rules):
print("Self.field is", self.field) # m2m 的情况
print("Self.field to", self.field.to) # 自定义的情况:比如 nodes, category
print("Self.field model", self.field.model) res = True
print("Self.field column", self.field.column) to_model = apps.get_model(self.field.to)
print("Self.field to", self.field.__dict__) src_model = self.field.model
field_name = self.field.name
custom_attr_filter = getattr(src_model, "get_filter_{}_attr_q".format(field_name), None)
@staticmethod custom_q = Q()
def attr_to_regex(attr):
"""将属性规则转换为正则表达式"""
name, value, match = attr['name'], attr['value'], attr['match']
if match == 'contains':
return r'.*{}.*'.format(escape_regex(value))
elif match == 'startswith':
return r'^{}.*'.format(escape_regex(value))
elif match == 'endswith':
return r'.*{}$'.format(escape_regex(value))
elif match == 'regex':
return value
elif match == 'not':
return r'^(?!^{}$)'.format(escape_regex(value))
elif match == 'in':
values = '|'.join(map(escape_regex, value))
return r'^(?:{})$'.format(values)
else:
return r'^{}$'.format(escape_regex(value))
def is_match(self, attr_dict, attr_rules):
for rule in attr_rules: for rule in attr_rules:
value = attr_dict.get(rule['name'], '') value = getattr(obj, rule['name'], '')
regex = self.attr_to_regex(rule) rule_value = rule.get('value', '')
if not re.match(regex, value): rule_match = rule.get('match', 'exact')
return False
return True if custom_attr_filter:
q = custom_attr_filter(rule['name'], rule_value, rule_match)
if q:
custom_q &= q
continue
if rule_match == 'in':
res &= value in rule_value
elif rule_match == 'exact':
res &= value == rule_value
elif rule_match == 'contains':
res &= rule_value in value
elif rule_match == 'startswith':
res &= str(value).startswith(str(rule_value))
elif rule_match == 'endswith':
res &= str(value).endswith(str(rule_value))
elif rule_match == 'regex':
res &= re.match(rule_value, value)
elif rule_match == 'not':
res &= value != rule_value
elif rule['match'] == 'gte':
res &= value >= rule_value
elif rule['match'] == 'lte':
res &= value <= rule_value
elif rule['match'] == 'gt':
res &= value > rule_value
elif rule['match'] == 'lt':
res &= value < rule_value
elif rule['match'] == 'ip_in':
if isinstance(rule_value, str):
rule_value = [rule_value]
res &= contains_ip(value, rule_value)
elif rule['match'] == 'm2m':
if isinstance(value, Manager):
value = value.values_list('id', flat=True)
value = set(map(str, value))
rule_value = set(map(str, rule_value))
res &= rule_value.issubset(value)
else:
logging.error("unknown match: {}".format(rule['match']))
res &= False
if not res:
return res
if custom_q:
res &= to_model.objects.filter(custom_q).filter(id=obj.id).exists()
return res
def get_filter_q(self, instance): def get_filter_q(self, instance):
model_cls = self.field.model model_cls = self.field.model
@ -457,18 +510,12 @@ class JSONManyToManyDescriptor:
queryset_id_attrs = model_cls.objects \ queryset_id_attrs = model_cls.objects \
.filter(**{'{}__type'.format(field_name): 'attrs'}) \ .filter(**{'{}__type'.format(field_name): 'attrs'}) \
.values_list('id', '{}__attrs'.format(field_name)) .values_list('id', '{}__attrs'.format(field_name))
instance_attr = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')} ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance, attr_rules)]
ids = [str(_id) for _id, attr_rules in queryset_id_attrs if self.is_match(instance_attr, attr_rules)]
if ids: if ids:
q |= Q(id__in=ids) q |= Q(id__in=ids)
return q return q
def escape_regex(s):
"""转义字符串中的正则表达式特殊字符"""
return re.sub('[.*+?^${}()|[\\]]', r'\\\g<0>', s)
class JSONManyToManyField(models.JSONField): class JSONManyToManyField(models.JSONField):
def __init__(self, to, *args, **kwargs): def __init__(self, to, *args, **kwargs):
self.to = to self.to = to
@ -490,7 +537,7 @@ class JSONManyToManyField(models.JSONField):
e = ValueError(_( e = ValueError(_(
"Invalid JSON data for JSONManyToManyField, should be like " "Invalid JSON data for JSONManyToManyField, should be like "
"{'type': 'all'} or {'type': 'ids', 'ids': []} " "{'type': 'all'} or {'type': 'ids', 'ids': []} "
"or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': 'value', 'rel': 'and|or|not'}}" "or {'type': 'attrs', 'attrs': [{'name': 'ip', 'match': 'exact', 'value': '1.1.1.1'}}"
)) ))
if not isinstance(val, dict): if not isinstance(val, dict):
raise e raise e

View File

@ -1,3 +1,4 @@
import ipaddress
import socket import socket
from ipaddress import ip_network, ip_address from ipaddress import ip_network, ip_address
@ -75,6 +76,23 @@ def contains_ip(ip, ip_group):
return False return False
def is_ip(self, ip, rule_value):
if rule_value == '*':
return True
elif '/' in rule_value:
network = ipaddress.ip_network(rule_value)
return ip in network.hosts()
elif '-' in rule_value:
start_ip, end_ip = rule_value.split('-')
start_ip = ipaddress.ip_address(start_ip)
end_ip = ipaddress.ip_address(end_ip)
return start_ip <= ip <= end_ip
elif len(rule_value.split('.')) == 4:
return ip == rule_value
else:
return ip.startswith(rule_value)
def get_ip_city(ip): def get_ip_city(ip):
if not ip or not isinstance(ip, str): if not ip or not isinstance(ip, str):
return _("Invalid address") return _("Invalid address")

View File

@ -668,7 +668,33 @@ class MFAMixin:
return backend return backend
class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser): class JSONFilterMixin:
"""
users = JSONManyToManyField('users.User', blank=True, null=True)
"""
@staticmethod
def get_json_filter_attr_q(name, value, match):
from rbac.models import RoleBinding
from orgs.utils import current_org
if name == 'system_roles':
user_id = RoleBinding.objects \
.filter(role__in=value, scope='system') \
.values_list('user_id', flat=True)
return models.Q(id__in=user_id)
elif name == 'org_roles':
kwargs = dict(role__in=value, scope='org')
if not current_org.is_root():
kwargs['org_id'] = current_org.id
user_id = RoleBinding.objects.filter(**kwargs) \
.values_list('user_id', flat=True)
return models.Q(id__in=user_id)
return None
class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, JSONFilterMixin, AbstractUser):
class Source(models.TextChoices): class Source(models.TextChoices):
local = 'local', _('Local') local = 'local', _('Local')
ldap = 'ldap', 'LDAP/AD' ldap = 'ldap', 'LDAP/AD'