pref: 修改 asset permission

pull/9048/head
ibuler 2022-11-11 15:04:31 +08:00
parent 644f3f1783
commit f6e403fd8b
32 changed files with 835 additions and 646 deletions

3
.isort.cfg Normal file
View File

@ -0,0 +1,3 @@
[settings]
line_length=120
known_first_party=common,users,assets,perms,authentication,jumpserver,notification,ops,orgs,rbac,settings,terminal,tickets

View File

@ -1,89 +1,91 @@
# -*- coding: utf-8 -*-
#
import django_filters
from rest_framework.decorators import action
from rest_framework.response import Response
from common.utils import get_logger
from common.drf.filters import BaseFilterSet
from common.mixins.api import SuggestionMixin
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics
from assets import serializers
from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBackend
from assets.models import Asset, Gateway
from assets.tasks import (
push_accounts_to_assets,
verify_accounts_connectivity,
test_assets_connectivity_manual,
update_assets_hardware_info_manual,
verify_accounts_connectivity,
)
from assets.filters import NodeFilterBackend, LabelFilterBackend, IpInFilterBackend
from common.drf.filters import BaseFilterSet
from common.mixins.api import SuggestionMixin
from common.utils import get_logger
from orgs.mixins import generics
from orgs.mixins.api import OrgBulkModelViewSet
from ..mixin import NodeFilterMixin
logger = get_logger(__file__)
__all__ = [
'AssetViewSet', 'AssetTaskCreateApi', 'AssetsTaskCreateApi',
"AssetViewSet",
"AssetTaskCreateApi",
"AssetsTaskCreateApi",
]
class AssetFilterSet(BaseFilterSet):
type = django_filters.CharFilter(field_name='platform__type', lookup_expr='exact')
category = django_filters.CharFilter(field_name='platform__category', lookup_expr='exact')
hostname = django_filters.CharFilter(field_name='name', lookup_expr='exact')
type = django_filters.CharFilter(field_name="platform__type", lookup_expr="exact")
category = django_filters.CharFilter(
field_name="platform__category", lookup_expr="exact"
)
hostname = django_filters.CharFilter(field_name="name", lookup_expr="exact")
class Meta:
model = Asset
fields = ['name', 'address', 'is_active', 'type', 'category', 'hostname']
fields = ["name", "address", "is_active", "type", "category", "hostname"]
class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
model = Asset
filterset_class = AssetFilterSet
search_fields = ("name", "address")
ordering_fields = ("name", "address")
ordering = ('name',)
ordering = ("name",)
serializer_classes = (
('default', serializers.AssetSerializer),
('suggestion', serializers.MiniAssetSerializer),
('platform', serializers.PlatformSerializer),
('gateways', serializers.GatewayWithAuthSerializer)
("default", serializers.AssetSerializer),
("suggestion", serializers.MiniAssetSerializer),
("platform", serializers.PlatformSerializer),
("gateways", serializers.GatewayWithAuthSerializer),
)
rbac_perms = (
('match', 'assets.match_asset'),
('platform', 'assets.view_platform'),
('gateways', 'assets.view_gateway')
("match", "assets.match_asset"),
("platform", "assets.view_platform"),
("gateways", "assets.view_gateway"),
)
extra_filter_backends = [
LabelFilterBackend,
IpInFilterBackend,
NodeFilterBackend
]
extra_filter_backends = [LabelFilterBackend, IpInFilterBackend, NodeFilterBackend]
@action(methods=['GET'], detail=True, url_path='platform')
@action(methods=["GET"], detail=True, url_path="platform")
def platform(self, *args, **kwargs):
asset = self.get_object()
serializer = self.get_serializer(asset.platform)
return Response(serializer.data)
@action(methods=['GET'], detail=True, url_path='gateways')
@action(methods=["GET"], detail=True, url_path="gateways")
def gateways(self, *args, **kwargs):
asset = self.get_object()
if not asset.domain:
gateways = Gateway.objects.none()
else:
gateways = asset.domain.gateways.filter(protocol='ssh')
gateways = asset.domain.gateways.filter(protocol="ssh")
return self.get_paginated_response_from_queryset(gateways)
class AssetsTaskMixin:
def perform_assets_task(self, serializer):
data = serializer.validated_data
assets = data.get('assets', [])
assets = data.get("assets", [])
asset_ids = [asset.id for asset in assets]
if data['action'] == "refresh":
if data["action"] == "refresh":
task = update_assets_hardware_info_manual.delay(asset_ids)
else:
task = test_assets_connectivity_manual.delay(asset_ids)
@ -94,9 +96,9 @@ class AssetsTaskMixin:
self.set_task_to_serializer_data(serializer, task)
def set_task_to_serializer_data(self, serializer, task):
data = getattr(serializer, '_data', {})
data = getattr(serializer, "_data", {})
data["task"] = task.id
setattr(serializer, '_data', data)
setattr(serializer, "_data", data)
class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
@ -104,18 +106,18 @@ class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
serializer_class = serializers.AssetTaskSerializer
def create(self, request, *args, **kwargs):
pk = self.kwargs.get('pk')
request.data['asset'] = pk
request.data['assets'] = [pk]
pk = self.kwargs.get("pk")
request.data["asset"] = pk
request.data["assets"] = [pk]
return super().create(request, *args, **kwargs)
def check_permissions(self, request):
action = request.data.get('action')
action = request.data.get("action")
action_perm_require = {
'refresh': 'assets.refresh_assethardwareinfo',
'push_account': 'assets.push_assetsystemuser',
'test': 'assets.test_assetconnectivity',
'test_account': 'assets.test_assetconnectivity'
"refresh": "assets.refresh_assethardwareinfo",
"push_account": "assets.push_assetsystemuser",
"test": "assets.test_assetconnectivity",
"test_account": "assets.test_assetconnectivity",
}
perm_required = action_perm_require.get(action)
has = self.request.user.has_perm(perm_required)
@ -126,19 +128,19 @@ class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
@staticmethod
def perform_asset_task(serializer):
data = serializer.validated_data
if data['action'] not in ['push_system_user', 'test_system_user']:
if data["action"] not in ["push_system_user", "test_system_user"]:
return
asset = data['asset']
accounts = data.get('accounts')
asset = data["asset"]
accounts = data.get("accounts")
if not accounts:
accounts = asset.accounts.all()
asset_ids = [asset.id]
account_ids = accounts.values_list('id', flat=True)
if action == 'push_account':
account_ids = accounts.values_list("id", flat=True)
if action == "push_account":
task = push_accounts_to_assets.delay(account_ids, asset_ids)
elif action == 'test_account':
elif action == "test_account":
task = verify_accounts_connectivity.delay(account_ids, asset_ids)
else:
task = None
@ -156,9 +158,9 @@ class AssetsTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
serializer_class = serializers.AssetsTaskSerializer
def check_permissions(self, request):
action = request.data.get('action')
action = request.data.get("action")
action_perm_require = {
'refresh': 'assets.refresh_assethardwareinfo',
"refresh": "assets.refresh_assethardwareinfo",
}
perm_required = action_perm_require.get(action)
has = self.request.user.has_perm(perm_required)

View File

@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Q
from django_filters import rest_framework as drf_filters
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
from django_filters import rest_framework as drf_filters
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.drf.filters import BaseFilterSet
from assets.utils import is_query_node_all_assets, get_node_from_request
from .models import Label, Node, Account
from .models import Account, Label, Node
class AssetByNodeFilterBackend(filters.BaseFilterBackend):

View File

@ -3,7 +3,8 @@ from django.utils.translation import gettext_lazy as _
from simple_history.models import HistoricalRecords
from common.utils import lazyproperty
from .base import BaseAccount, AbsConnectivity
from .base import AbsConnectivity, BaseAccount
__all__ = ['Account', 'AccountTemplate']
@ -40,9 +41,10 @@ class AccountHistoricalRecords(HistoricalRecords):
class Account(AbsConnectivity, BaseAccount):
class InnerAccount(models.TextChoices):
INPUT = '@INPUT', '@INPUT'
USER = '@USER', '@USER'
class AliasAccount(models.TextChoices):
ALL = '@ALL', _('All')
INPUT = '@INPUT', _('Manual input')
USER = '@USER', _('Dynamic user')
asset = models.ForeignKey(
'assets.Asset', related_name='accounts',
@ -76,14 +78,14 @@ class Account(AbsConnectivity, BaseAccount):
return '{}'.format(self.username)
@classmethod
def get_input_account(cls):
def get_manual_account(cls):
""" @INPUT 手动登录的账号(any) """
return cls(name=cls.InnerAccount.INPUT.value, username='')
return cls(name=cls.AliasAccount.INPUT.label, username=cls.AliasAccount.INPUT.value, secret=None)
@classmethod
def get_user_account(cls, username):
""" @USER 动态用户的账号(self) """
return cls(name=cls.InnerAccount.USER.value, username=username)
return cls(name=cls.AliasAccount.USER.label, username=cls.AliasAccount.USER.value)
class AccountTemplate(BaseAccount):

View File

@ -1,61 +1,75 @@
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.drf.fields import LabeledChoiceField
from common.drf.serializers import WritableNestedModelSerializer
from ..models import Platform, PlatformProtocol, PlatformAutomation
from ..const import Category, AllTypes
from ..models import Platform, PlatformProtocol, PlatformAutomation
__all__ = ['PlatformSerializer', 'PlatformOpsMethodSerializer']
__all__ = ["PlatformSerializer", "PlatformOpsMethodSerializer"]
class ProtocolSettingSerializer(serializers.Serializer):
SECURITY_CHOICES = [
('any', 'Any'),
('rdp', 'RDP'),
('tls', 'TLS'),
('nla', 'NLA'),
("any", "Any"),
("rdp", "RDP"),
("tls", "TLS"),
("nla", "NLA"),
]
# RDP
console = serializers.BooleanField(required=False)
security = serializers.ChoiceField(choices=SECURITY_CHOICES, default='any')
security = serializers.ChoiceField(choices=SECURITY_CHOICES, default="any")
# SFTP
sftp_enabled = serializers.BooleanField(default=True, label=_("SFTP enabled"))
sftp_home = serializers.CharField(default='/tmp', label=_("SFTP home"))
sftp_home = serializers.CharField(default="/tmp", label=_("SFTP home"))
# HTTP
auto_fill = serializers.BooleanField(default=False, label=_("Auto fill"))
username_selector = serializers.CharField(default='', allow_blank=True, label=_("Username selector"))
password_selector = serializers.CharField(default='', allow_blank=True, label=_("Password selector"))
submit_selector = serializers.CharField(default='', allow_blank=True, label=_("Submit selector"))
username_selector = serializers.CharField(
default="", allow_blank=True, label=_("Username selector")
)
password_selector = serializers.CharField(
default="", allow_blank=True, label=_("Password selector")
)
submit_selector = serializers.CharField(
default="", allow_blank=True, label=_("Submit selector")
)
class PlatformAutomationSerializer(serializers.ModelSerializer):
class Meta:
model = PlatformAutomation
fields = [
'id', 'ansible_enabled', 'ansible_config',
'ping_enabled', 'ping_method',
'gather_facts_enabled', 'gather_facts_method',
'push_account_enabled', 'push_account_method',
'change_secret_enabled', 'change_secret_method',
'verify_account_enabled', 'verify_account_method',
'gather_accounts_enabled', 'gather_accounts_method',
"id",
"ansible_enabled",
"ansible_config",
"ping_enabled",
"ping_method",
"gather_facts_enabled",
"gather_facts_method",
"push_account_enabled",
"push_account_method",
"change_secret_enabled",
"change_secret_method",
"verify_account_enabled",
"verify_account_method",
"gather_accounts_enabled",
"gather_accounts_method",
]
extra_kwargs = {
'ping_enabled': {'label': '启用资产探测'},
'ping_method': {'label': '探测方式'},
'gather_facts_enabled': {'label': '启用收集信息'},
'gather_facts_method': {'label': '收集信息方式'},
'verify_account_enabled': {'label': '启用校验账号'},
'verify_account_method': {'label': '校验账号方式'},
'push_account_enabled': {'label': '启用推送账号'},
'push_account_method': {'label': '推送账号方式'},
'change_secret_enabled': {'label': '启用账号改密'},
'change_secret_method': {'label': '账号创建改密方式'},
'gather_accounts_enabled': {'label': '启用账号收集'},
'gather_accounts_method': {'label': '收集账号方式'},
"ping_enabled": {"label": "启用资产探测"},
"ping_method": {"label": "探测方式"},
"gather_facts_enabled": {"label": "启用收集信息"},
"gather_facts_method": {"label": "收集信息方式"},
"verify_account_enabled": {"label": "启用校验账号"},
"verify_account_method": {"label": "校验账号方式"},
"push_account_enabled": {"label": "启用推送账号"},
"push_account_method": {"label": "推送账号方式"},
"change_secret_enabled": {"label": "启用账号改密"},
"change_secret_method": {"label": "账号创建改密方式"},
"gather_accounts_enabled": {"label": "启用账号收集"},
"gather_accounts_method": {"label": "收集账号方式"},
}
@ -66,42 +80,62 @@ class PlatformProtocolsSerializer(serializers.ModelSerializer):
class Meta:
model = PlatformProtocol
fields = [
'id', 'name', 'port', 'primary', 'default',
'required', 'secret_types', 'setting',
"id",
"name",
"port",
"primary",
"default",
"required",
"secret_types",
"setting",
]
class PlatformSerializer(WritableNestedModelSerializer):
charset = LabeledChoiceField(
choices=Platform.CharsetChoices.choices, label=_("Charset")
)
type = LabeledChoiceField(choices=AllTypes.choices(), label=_("Type"))
category = LabeledChoiceField(choices=Category.choices, label=_("Category"))
protocols = PlatformProtocolsSerializer(label=_('Protocols'), many=True, required=False)
automation = PlatformAutomationSerializer(label=_('Automation'), required=False)
protocols = PlatformProtocolsSerializer(
label=_("Protocols"), many=True, required=False
)
automation = PlatformAutomationSerializer(label=_("Automation"), required=False)
su_method = LabeledChoiceField(
choices=[('sudo', 'sudo su -'), ('su', 'su - ')],
label='切换方式', required=False, default='sudo'
choices=[("sudo", "sudo su -"), ("su", "su - ")],
label="切换方式",
required=False,
default="sudo",
)
class Meta:
model = Platform
fields_mini = ['id', 'name', 'internal']
fields_mini = ["id", "name", "internal"]
fields_small = fields_mini + [
'category', 'type', 'charset',
"category",
"type",
"charset",
]
fields = fields_small + [
'protocols_enabled', 'protocols', 'domain_enabled',
'su_enabled', 'su_method', 'automation', 'comment',
"protocols_enabled",
"protocols",
"domain_enabled",
"su_enabled",
"su_method",
"automation",
"comment",
]
extra_kwargs = {
'su_enabled': {'label': '启用切换账号'},
'protocols_enabled': {'label': '启用协议'},
'domain_enabled': {'label': "启用网域"},
'domain_default': {'label': "默认网域"},
"su_enabled": {"label": "启用切换账号"},
"protocols_enabled": {"label": "启用协议"},
"domain_enabled": {"label": "启用网域"},
"domain_default": {"label": "默认网域"},
}
class PlatformOpsMethodSerializer(serializers.Serializer):
id = serializers.CharField(read_only=True)
name = serializers.CharField(max_length=50, label=_('Name'))
category = serializers.CharField(max_length=50, label=_('Category'))
name = serializers.CharField(max_length=50, label=_("Name"))
category = serializers.CharField(max_length=50, label=_("Category"))
type = serializers.ListSerializer(child=serializers.CharField())
method = serializers.CharField()

View File

@ -16,7 +16,7 @@ from rest_framework.request import Request
from common.drf.api import JMSModelViewSet
from common.http import is_true
from orgs.mixins.api import RootOrgViewMixin
from perms.models import Action
from perms.models import ActionChoices
from terminal.models import EndpointRule
from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
@ -70,8 +70,8 @@ class RDPFileClientProtocolURLMixin:
# 设置磁盘挂载
drives_redirect = is_true(self.request.query_params.get('drives_redirect'))
if drives_redirect:
actions = Action.choices_to_value(token.actions)
if actions & Action.UPDOWNLOAD == Action.UPDOWNLOAD:
actions = ActionChoices.choices_to_value(token.actions)
if actions & Action.TRANSFER == Action.TRANSFER:
rdp_options['drivestoredirect:s'] = '*'
# 设置全屏

View File

@ -7,7 +7,7 @@ from common.utils import pretty_string
from common.utils.random import random_string
from assets.models import Asset, Gateway, Domain, CommandFilterRule, Account
from users.models import User
from perms.serializers.permission import ActionsField
from perms.serializers.permission import ActionChoicesField
__all__ = [
@ -158,14 +158,13 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin):
gateway = ConnectionTokenGatewaySerializer(read_only=True)
domain = ConnectionTokenDomainSerializer(read_only=True)
cmd_filter_rules = ConnectionTokenCmdFilterRuleSerializer(many=True)
actions = ActionsField()
actions = ActionChoicesField()
expire_at = serializers.IntegerField()
class Meta:
model = ConnectionToken
fields = [
'id', 'secret',
'user', 'asset', 'account_username', 'account', 'protocol',
'domain', 'gateway', 'cmd_filter_rules',
'actions', 'expire_at',
'id', 'secret', 'user', 'asset', 'account_username',
'account', 'protocol', 'domain', 'gateway',
'cmd_filter_rules', 'actions', 'expire_at',
]

View File

@ -1,27 +1,28 @@
from urllib.parse import urlencode
from django.conf import settings
from django.db.utils import IntegrityError
from django.http.request import HttpRequest
from django.http.response import HttpResponseRedirect
from django.utils.translation import ugettext_lazy as _
from urllib.parse import urlencode
from django.views import View
from django.conf import settings
from django.http.request import HttpRequest
from django.db.utils import IntegrityError
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.exceptions import APIException
from rest_framework.permissions import AllowAny, IsAuthenticated
from authentication import errors
from authentication.const import ConfirmType
from authentication.mixins import AuthMixin
from authentication.notifications import OAuthBindMessage
from common.mixins.views import PermissionsMixin, UserConfirmRequiredExceptionMixin
from common.permissions import UserConfirmation
from common.sdk.im.dingtalk import URL, DingTalk
from common.utils import FlashMessageUtil, get_logger
from common.utils.common import get_request_ip
from common.utils.django import get_object_or_none, reverse
from common.utils.random import random_string
from users.models import User
from users.views import UserVerifyPasswordView
from common.utils import get_logger, FlashMessageUtil
from common.utils.random import random_string
from common.utils.django import reverse, get_object_or_none
from common.sdk.im.dingtalk import URL
from common.mixins.views import UserConfirmRequiredExceptionMixin, PermissionsMixin
from common.permissions import UserConfirmation
from authentication import errors
from authentication.mixins import AuthMixin
from authentication.const import ConfirmType
from common.sdk.im.dingtalk import DingTalk
from common.utils.common import get_request_ip
from authentication.notifications import OAuthBindMessage
from .mixins import METAMixin
logger = get_logger(__file__)

View File

@ -1,26 +1,27 @@
from urllib.parse import urlencode
from django.conf import settings
from django.db.utils import IntegrityError
from django.http.request import HttpRequest
from django.http.response import HttpResponseRedirect
from django.utils.translation import ugettext_lazy as _
from urllib.parse import urlencode
from django.views import View
from django.conf import settings
from django.http.request import HttpRequest
from django.db.utils import IntegrityError
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.exceptions import APIException
from rest_framework.permissions import AllowAny, IsAuthenticated
from users.models import User
from users.views import UserVerifyPasswordView
from common.utils import get_logger, FlashMessageUtil
from common.utils.random import random_string
from common.utils.django import reverse, get_object_or_none
from common.mixins.views import UserConfirmRequiredExceptionMixin, PermissionsMixin
from common.permissions import UserConfirmation
from common.sdk.im.feishu import FeiShu, URL
from common.utils.common import get_request_ip
from authentication import errors
from authentication.const import ConfirmType
from authentication.mixins import AuthMixin
from authentication.notifications import OAuthBindMessage
from common.mixins.views import PermissionsMixin, UserConfirmRequiredExceptionMixin
from common.permissions import UserConfirmation
from common.sdk.im.feishu import URL, FeiShu
from common.utils import FlashMessageUtil, get_logger
from common.utils.common import get_request_ip
from common.utils.django import get_object_or_none, reverse
from common.utils.random import random_string
from users.models import User
from users.views import UserVerifyPasswordView
logger = get_logger(__file__)

View File

@ -1,10 +1,12 @@
# -*- coding: utf-8 -*-
#
import json
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.utils.encoding import force_text
from django.core.validators import MinValueValidator, MaxValueValidator
from common.utils import signer, crypto
@ -13,7 +15,7 @@ __all__ = [
'JsonCharField', 'JsonTextField', 'JsonListCharField', 'JsonListTextField',
'JsonDictCharField', 'JsonDictTextField', 'EncryptCharField',
'EncryptTextField', 'EncryptMixin', 'EncryptJsonDictTextField',
'EncryptJsonDictCharField', 'PortField'
'EncryptJsonDictCharField', 'PortField', 'BitChoices',
]
@ -190,3 +192,37 @@ class PortField(models.IntegerField):
})
super().__init__(*args, **kwargs)
class BitChoices(models.IntegerChoices):
@classmethod
def branches(cls):
return [i for i in cls]
@classmethod
def tree(cls):
root = [_('All'), cls.branches()]
return cls.render_node(root)
@classmethod
def render_node(cls, node):
if isinstance(node, BitChoices):
return {
'id': node.name,
'label': node.label,
}
else:
name, children = node
return {
'id': name,
'label': name,
'children': [cls.render_node(child) for child in children]
}
@classmethod
def all(cls):
value = 0
for c in cls:
value |= c.value
return value

View File

@ -1,17 +1,20 @@
# -*- coding: utf-8 -*-
#
import six
from rest_framework.fields import ChoiceField
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import IntegerChoices
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from rest_framework.fields import ChoiceField
from common.utils import decrypt_password
__all__ = [
'ReadableHiddenField', 'EncryptedField', 'LabeledChoiceField',
'ObjectRelatedField',
"ReadableHiddenField",
"EncryptedField",
"LabeledChoiceField",
"ObjectRelatedField",
"BitChoicesField",
]
@ -20,14 +23,15 @@ __all__ = [
class ReadableHiddenField(serializers.HiddenField):
""" 可读的 HiddenField """
"""可读的 HiddenField"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.write_only = False
def to_representation(self, value):
if hasattr(value, 'id'):
return getattr(value, 'id')
if hasattr(value, "id"):
return getattr(value, "id")
return value
@ -35,7 +39,7 @@ class EncryptedField(serializers.CharField):
def __init__(self, write_only=None, **kwargs):
if write_only is None:
write_only = True
kwargs['write_only'] = write_only
kwargs["write_only"] = write_only
super().__init__(**kwargs)
def to_internal_value(self, value):
@ -54,26 +58,26 @@ class LabeledChoiceField(ChoiceField):
if value is None:
return value
return {
'value': value,
'label': self.choice_mapper.get(six.text_type(value), value),
"value": value,
"label": self.choice_mapper.get(six.text_type(value), value),
}
def to_internal_value(self, data):
if isinstance(data, dict):
return data.get('value')
return data.get("value")
return super(LabeledChoiceField, self).to_internal_value(data)
class ObjectRelatedField(serializers.RelatedField):
default_error_messages = {
'required': _('This field is required.'),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
"required": _("This field is required."),
"does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
"incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
}
def __init__(self, **kwargs):
self.attrs = kwargs.pop('attrs', None) or ('id', 'name')
self.many = kwargs.get('many', False)
self.attrs = kwargs.pop("attrs", None) or ("id", "name")
self.many = kwargs.get("many", False)
super().__init__(**kwargs)
def to_representation(self, value):
@ -86,13 +90,53 @@ class ObjectRelatedField(serializers.RelatedField):
if not isinstance(data, dict):
pk = data
else:
pk = data.get('id') or data.get('pk') or data.get(self.attrs[0])
pk = data.get("id") or data.get("pk") or data.get(self.attrs[0])
queryset = self.get_queryset()
try:
if isinstance(data, bool):
raise TypeError
return queryset.get(pk=pk)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=pk)
self.fail("does_not_exist", pk_value=pk)
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(pk).__name__)
self.fail("incorrect_type", data_type=type(pk).__name__)
class BitChoicesField(serializers.MultipleChoiceField):
"""
位字段
"""
def __init__(self, choice_cls, **kwargs):
assert issubclass(choice_cls, IntegerChoices)
choices = [(c.name, c.label) for c in choice_cls]
self._choice_cls = choice_cls
super().__init__(choices=choices, **kwargs)
def to_representation(self, value):
return [
{"value": c.name, "label": c.label}
for c in self._choice_cls
if c.value & value == c.value
]
def to_internal_value(self, data):
if not isinstance(data, list):
raise serializers.ValidationError(_("Invalid data type, should be list"))
value = 0
if not data:
return value
if isinstance(data[0], dict):
data = [d["value"] for d in data]
# 所有的
if "all" in data:
for c in self._choice_cls:
value |= c.value
return value
name_value_map = {c.name: c.value for c in self._choice_cls}
for name in data:
if name not in name_value_map:
raise serializers.ValidationError(_("Invalid choice: {}").format(name))
value |= name_value_map[name]
return value

View File

@ -2,17 +2,15 @@
#
from __future__ import unicode_literals
from collections import OrderedDict
import datetime
from itertools import chain
from collections import OrderedDict
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.encoding import force_text
from rest_framework.fields import empty
from rest_framework.metadata import SimpleMetadata
from rest_framework import exceptions, serializers
from rest_framework.fields import empty
from rest_framework.metadata import SimpleMetadata
from rest_framework.request import clone_request
@ -21,9 +19,14 @@ class SimpleMetadataWithFilters(SimpleMetadata):
methods = {"PUT", "POST", "GET", "PATCH"}
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value', "write_only",
"read_only",
"label",
"help_text",
"min_length",
"max_length",
"min_value",
"max_value",
"write_only",
]
def determine_actions(self, request, view):
@ -32,18 +35,18 @@ class SimpleMetadataWithFilters(SimpleMetadata):
the fields that are accepted for 'PUT' and 'POST' methods.
"""
actions = {}
view.raw_action = getattr(view, 'action', None)
view.raw_action = getattr(view, "action", None)
for method in self.methods & set(view.allowed_methods):
if hasattr(view, 'action_map'):
if hasattr(view, "action_map"):
view.action = view.action_map.get(method.lower(), view.action)
view.request = clone_request(request, method)
try:
# Test global permissions
if hasattr(view, 'check_permissions'):
if hasattr(view, "check_permissions"):
view.check_permissions(view.request)
# Test object permissions
if method == 'PUT' and hasattr(view, 'get_object'):
if method == "PUT" and hasattr(view, "get_object"):
view.get_object()
except (exceptions.APIException, PermissionDenied, Http404):
pass
@ -62,64 +65,63 @@ class SimpleMetadataWithFilters(SimpleMetadata):
of metadata about it.
"""
field_info = OrderedDict()
field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False)
field_info["type"] = self.label_lookup[field]
field_info["required"] = getattr(field, "required", False)
default = getattr(field, 'default', None)
# Default value
default = getattr(field, "default", None)
if default is not None and default != empty:
if isinstance(default, (str, int, bool, float, datetime.datetime, list)):
field_info['default'] = default
field_info["default"] = default
for attr in self.attrs:
value = getattr(field, attr, None)
if value is not None and value != '':
if value is not None and value != "":
field_info[attr] = force_text(value, strings_only=True)
if getattr(field, 'child', None):
field_info['child'] = self.get_field_info(field.child)
elif getattr(field, 'fields', None):
field_info['children'] = self.get_serializer_info(field)
if getattr(field, "child", None):
field_info["child"] = self.get_field_info(field.child)
elif getattr(field, "fields", None):
field_info["children"] = self.get_serializer_info(field)
is_related_field = isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField))
if not is_related_field and hasattr(field, 'choices'):
field_info['choices'] = [
is_choice_field = isinstance(field, (serializers.ChoiceField,))
if is_choice_field and hasattr(field, "choices"):
field_info["choices"] = [
{
'value': choice_value,
'label': force_text(choice_name, strings_only=True)
"value": choice_value,
"label": force_text(choice_label, strings_only=True),
}
for choice_value, choice_name in dict(field.choices).items()
for choice_value, choice_label in dict(field.choices).items()
]
class_name = field.__class__.__name__
if class_name == 'LabeledChoiceField':
field_info['type'] = 'labeled_choice'
elif class_name == 'ObjectRelatedField':
field_info['type'] = 'object_related_field'
elif class_name == 'ManyRelatedField':
if class_name == "LabeledChoiceField":
field_info["type"] = "labeled_choice"
elif class_name == "ObjectRelatedField":
field_info["type"] = "object_related_field"
elif class_name == "ManyRelatedField":
child_relation_class_name = field.child_relation.__class__.__name__
if child_relation_class_name == 'ObjectRelatedField':
field_info['type'] = 'm2m_related_field'
# if field.label == '系统平台':
# print("Field: ", class_name, field, field_info)
if child_relation_class_name == "ObjectRelatedField":
field_info["type"] = "m2m_related_field"
return field_info
def get_filters_fields(self, request, view):
@staticmethod
def get_filters_fields(request, view):
fields = []
if hasattr(view, 'get_filter_fields'):
if hasattr(view, "get_filter_fields"):
fields = view.get_filter_fields(request)
elif hasattr(view, 'filter_fields'):
elif hasattr(view, "filter_fields"):
fields = view.filter_fields
elif hasattr(view, 'filterset_fields'):
elif hasattr(view, "filterset_fields"):
fields = view.filterset_fields
elif hasattr(view, 'get_filterset_fields'):
elif hasattr(view, "get_filterset_fields"):
fields = view.get_filterset_fields(request)
elif hasattr(view, 'filterset_class'):
fields = list(view.filterset_class.Meta.fields) + \
list(view.filterset_class.declared_filters.keys())
elif hasattr(view, "filterset_class"):
fields = list(view.filterset_class.Meta.fields) + list(
view.filterset_class.declared_filters.keys()
)
if hasattr(view, 'custom_filter_fields'):
if hasattr(view, "custom_filter_fields"):
# 不能写 fields += view.custom_filter_fields
# 会改变 view 的 filter_fields
fields = list(fields) + list(view.custom_filter_fields)
@ -130,14 +132,16 @@ class SimpleMetadataWithFilters(SimpleMetadata):
def get_ordering_fields(self, request, view):
fields = []
if hasattr(view, 'get_ordering_fields'):
if hasattr(view, "get_ordering_fields"):
fields = view.get_ordering_fields(request)
elif hasattr(view, 'ordering_fields'):
elif hasattr(view, "ordering_fields"):
fields = view.ordering_fields
return fields
def determine_metadata(self, request, view):
metadata = super(SimpleMetadataWithFilters, self).determine_metadata(request, view)
metadata = super(SimpleMetadataWithFilters, self).determine_metadata(
request, view
)
filterset_fields = self.get_filters_fields(request, view)
order_fields = self.get_ordering_fields(request, view)

View File

@ -0,0 +1,3 @@
def bit(x):
return 2 ** (x - 1)

View File

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
#
from perms.filters import AssetPermissionFilter
from perms.models import AssetPermission
from orgs.mixins.api import OrgBulkModelViewSet
from perms import serializers
from perms.filters import AssetPermissionFilter
from perms.models import AssetPermission
__all__ = ['AssetPermissionViewSet']
@ -18,4 +17,4 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
filterset_class = AssetPermissionFilter
search_fields = ('name',)
ordering_fields = ('name',)
ordering = ('name', )
ordering = ('name',)

View File

@ -6,7 +6,6 @@ from common.utils import get_logger, lazyproperty
from assets.serializers import AccountSerializer
from perms.hands import User, Asset, Account
from perms import serializers
from perms.models import Action
from perms.utils import PermAccountUtil
from .mixin import RoleAdminMixin, RoleUserMixin
@ -80,7 +79,7 @@ class UserGrantedAssetSpecialAccountsApi(ListAPIView):
def get_queryset(self):
# 构造默认包含的账号,如: @INPUT @USER
accounts = [
Account.get_input_account(),
Account.get_manual_account(),
Account.get_user_account(self.user.username)
]
for account in accounts:

View File

@ -3,11 +3,9 @@
from rest_framework.request import Request
from common.http import is_true
from common.mixins.api import RoleAdminMixin
from common.mixins.api import RoleUserMixin
from orgs.utils import tmp_to_root_org
from users.models import User
from common.mixins.api import RoleAdminMixin, RoleUserMixin
from perms.utils.user_permission import UserGrantedTreeRefreshController
from users.models import User
class RebuildTreeMixin:

View File

@ -1,2 +1,71 @@
# -*- coding: utf-8 -*-
#
from django.db import models
from django.utils.translation import ugettext_lazy as _
from common.utils.integer import bit
from common.db.fields import BitChoices
__all__ = ['SpecialAccount', 'ActionChoices']
class ActionChoices(BitChoices):
connect = bit(0), _('Connect')
upload = bit(1), _('Upload')
download = bit(2), _('Download')
copy = bit(3), _('Copy')
paste = bit(4), _('Paste')
@classmethod
def branches(cls):
return (
(_('Transfer'), [cls.upload, cls.download]),
(_('Clipboard'), [cls.copy, cls.paste]),
)
# class Action(BitOperationChoice):
# CONNECT = 0b1
# UPLOAD = 0b1 << 1
# DOWNLOAD = 0b1 << 2
# COPY = 0b1 << 3
# PASTE = 0b1 << 4
# ALL = 0 << 8
# TRANSFER = UPLOAD | DOWNLOAD
# CLIPBOARD = COPY | PASTE
#
# DB_CHOICES = (
# (ALL, _('All')),
# (CONNECT, _('Connect')),
# (UPLOAD, _('Upload file')),
# (DOWNLOAD, _('Download file')),
# (TRANSFER, _("Upload download")),
# (COPY, _('Clipboard copy')),
# (PASTE, _('Clipboard paste')),
# (CLIPBOARD, _('Clipboard copy paste'))
# )
#
# NAME_MAP = {
# ALL: "all",
# CONNECT: "connect",
# UPLOAD: "upload",
# DOWNLOAD: "download",
# TRANSFER: "transfer",
# COPY: 'copy',
# PASTE: 'paste',
# CLIPBOARD: 'clipboard'
# }
#
# NAME_MAP_REVERSE = {v: k for k, v in NAME_MAP.items()}
# CHOICES = []
# for i, j in DB_CHOICES:
# CHOICES.append((NAME_MAP[i], j))
#
# @classmethod
# def choices(cls):
# pass
#
class SpecialAccount(models.TextChoices):
ALL = '@ALL', 'All'

View File

@ -5,7 +5,5 @@ class UserGrantedTreeRebuildLock(DistributedLock):
name_template = 'perms.user.asset.node.tree.rebuid.<user_id:{user_id}>'
def __init__(self, user_id):
name = self.name_template.format(
user_id=user_id
)
name = self.name_template.format(user_id=user_id)
super().__init__(name=name, release_on_transaction_commit=True)

View File

@ -3,13 +3,12 @@
from django.db import migrations, models
from django.db.models import F
from perms.models import Action
def migrate_asset_permission(apps, schema_editor):
# 已有的资产权限默认拥有剪切板复制粘贴动作
AssetPermission = apps.get_model('perms', 'AssetPermission')
AssetPermission.objects.all().update(actions=F('actions').bitor(Action.CLIPBOARD_COPY_PASTE))
asset_permission_model = apps.get_model('perms', 'AssetPermission')
asset_permission_model.objects.all().update(actions=F('actions').bitor(24))
class Migration(migrations.Migration):

View File

@ -1,5 +1,5 @@
# coding: utf-8
#
from .permed_node import *
from .asset_permission import *
from .const import *

View File

@ -1,23 +1,19 @@
import uuid
import logging
import uuid
from django.db import models
from django.db.models import Q
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
from django.db import models
from django.db.models import F, Q, TextChoices
from common.utils import lazyproperty, date_expired_default
from common.db.models import BaseCreateUpdateModel, UnionQuerySet
from assets.models import Asset, Node, FamilyMixin, Account
from orgs.mixins.models import OrgModelMixin
from assets.models import Asset, Account
from common.db.models import UnionQuerySet
from common.utils import date_expired_default
from orgs.mixins.models import OrgManager
from .const import Action, SpecialAccount
from orgs.mixins.models import OrgModelMixin
from perms.const import ActionChoices, SpecialAccount
__all__ = [
'AssetPermission', 'PermNode',
'UserAssetGrantedTreeNodeRelation',
'Action'
]
__all__ = ['AssetPermission', 'ActionChoices']
# 使用场景
logger = logging.getLogger(__name__)
@ -67,9 +63,7 @@ class AssetPermission(OrgModelMixin):
)
# 特殊的账号: @ALL, @INPUT @USER 默认包含,将来在全局设置中进行控制.
accounts = models.JSONField(default=list, verbose_name=_("Accounts"))
actions = models.IntegerField(
choices=Action.DB_CHOICES, default=Action.ALL, verbose_name=_("Actions")
)
actions = models.IntegerField(default=ActionChoices.connect, verbose_name=_("Actions"))
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
date_start = models.DateTimeField(
default=timezone.now, db_index=True, verbose_name=_("Date start")
@ -133,145 +127,9 @@ class AssetPermission(OrgModelMixin):
"""
asset_ids = self.get_all_assets(flat=True)
q = Q(asset_id__in=asset_ids)
if not self.is_perm_all_accounts:
if SpecialAccount.ALL in self.accounts:
q &= Q(username__in=self.accounts)
accounts = Account.objects.filter(q).order_by('asset__name', 'name', 'username')
if not flat:
return accounts
return accounts.values_list('id', flat=True)
@property
def is_perm_all_accounts(self):
return SpecialAccount.ALL in self.accounts
@lazyproperty
def users_amount(self):
return self.users.count()
@lazyproperty
def user_groups_amount(self):
return self.user_groups.count()
@lazyproperty
def assets_amount(self):
return self.assets.count()
@lazyproperty
def nodes_amount(self):
return self.nodes.count()
def users_display(self):
names = [user.username for user in self.users.all()]
return names
def user_groups_display(self):
names = [group.name for group in self.user_groups.all()]
return names
def assets_display(self):
names = [asset.name for asset in self.assets.all()]
return names
def nodes_display(self):
names = [node.full_value for node in self.nodes.all()]
return names
class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, BaseCreateUpdateModel):
class NodeFrom(TextChoices):
granted = 'granted', 'Direct node granted'
child = 'child', 'Have children node'
asset = 'asset', 'Direct asset granted'
user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE)
node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE,
db_constraint=False, null=False, related_name='granted_node_rels')
node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True)
node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'),
db_index=True)
node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True)
node_assets_amount = models.IntegerField(default=0)
@property
def key(self):
return self.node_key
@property
def parent_key(self):
return self.node_parent_key
@classmethod
def get_node_granted_status(cls, user, key):
ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True))
ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys)
for rel_node in ancestor_rel_nodes:
if rel_node.key == key:
return rel_node.node_from, rel_node
if rel_node.node_from == cls.NodeFrom.granted:
return cls.NodeFrom.granted, None
return '', None
class PermNode(Node):
class Meta:
proxy = True
ordering = []
# 特殊节点
UNGROUPED_NODE_KEY = 'ungrouped'
UNGROUPED_NODE_VALUE = _('Ungrouped')
FAVORITE_NODE_KEY = 'favorite'
FAVORITE_NODE_VALUE = _('Favorite')
node_from = ''
granted_assets_amount = 0
annotate_granted_node_rel_fields = {
'granted_assets_amount': F('granted_node_rels__node_assets_amount'),
'node_from': F('granted_node_rels__node_from')
}
def use_granted_assets_amount(self):
self.assets_amount = self.granted_assets_amount
@classmethod
def get_ungrouped_node(cls, assets_amount):
return cls(
id=cls.UNGROUPED_NODE_KEY,
key=cls.UNGROUPED_NODE_KEY,
value=cls.UNGROUPED_NODE_VALUE,
assets_amount=assets_amount
)
@classmethod
def get_favorite_node(cls, assets_amount):
node = cls(
id=cls.FAVORITE_NODE_KEY,
key=cls.FAVORITE_NODE_KEY,
value=cls.FAVORITE_NODE_VALUE,
)
node.assets_amount = assets_amount
return node
def get_granted_status(self, user):
status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key)
self.node_from = status
if rel_node:
self.granted_assets_amount = rel_node.node_assets_amount
return status
def save(self):
# 这是个只读 Model
raise NotImplementedError
class PermedAsset(Asset):
class Meta:
proxy = True
verbose_name = _('Permed asset')
permissions = [
('view_myassets', _('Can view my assets')),
('view_userassets', _('Can view user assets')),
('view_usergroupassets', _('Can view usergroup assets')),
]

View File

@ -1,48 +0,0 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from common.db.models import BitOperationChoice
__all__ = ['Action', 'SpecialAccount']
class Action(BitOperationChoice):
ALL = 0xff
CONNECT = 0b1
UPLOAD = 0b1 << 1
DOWNLOAD = 0b1 << 2
CLIPBOARD_COPY = 0b1 << 3
CLIPBOARD_PASTE = 0b1 << 4
UPDOWNLOAD = UPLOAD | DOWNLOAD
CLIPBOARD_COPY_PASTE = CLIPBOARD_COPY | CLIPBOARD_PASTE
DB_CHOICES = (
(ALL, _('All')),
(CONNECT, _('Connect')),
(UPLOAD, _('Upload file')),
(DOWNLOAD, _('Download file')),
(UPDOWNLOAD, _("Upload download")),
(CLIPBOARD_COPY, _('Clipboard copy')),
(CLIPBOARD_PASTE, _('Clipboard paste')),
(CLIPBOARD_COPY_PASTE, _('Clipboard copy paste'))
)
NAME_MAP = {
ALL: "all",
CONNECT: "connect",
UPLOAD: "upload_file",
DOWNLOAD: "download_file",
UPDOWNLOAD: "updownload",
CLIPBOARD_COPY: 'clipboard_copy',
CLIPBOARD_PASTE: 'clipboard_paste',
CLIPBOARD_COPY_PASTE: 'clipboard_copy_paste'
}
NAME_MAP_REVERSE = {v: k for k, v in NAME_MAP.items()}
CHOICES = []
for i, j in DB_CHOICES:
CHOICES.append((NAME_MAP[i], j))
class SpecialAccount(models.TextChoices):
ALL = '@ALL', 'All'

View File

@ -0,0 +1,119 @@
from django.utils.translation import ugettext_lazy as _
from django.db import models
from django.db.models import F, TextChoices
from common.utils import lazyproperty
from common.db.models import BaseCreateUpdateModel
from assets.models import Asset, Node, FamilyMixin, Account
from orgs.mixins.models import OrgModelMixin
class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, BaseCreateUpdateModel):
class NodeFrom(TextChoices):
granted = 'granted', 'Direct node granted'
child = 'child', 'Have children node'
asset = 'asset', 'Direct asset granted'
user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE)
node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE,
db_constraint=False, null=False, related_name='granted_node_rels')
node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True)
node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'),
db_index=True)
node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True)
node_assets_amount = models.IntegerField(default=0)
@property
def key(self):
return self.node_key
@property
def parent_key(self):
return self.node_parent_key
@classmethod
def get_node_granted_status(cls, user, key):
ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True))
ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys)
for rel_node in ancestor_rel_nodes:
if rel_node.key == key:
return rel_node.node_from, rel_node
if rel_node.node_from == cls.NodeFrom.granted:
return cls.NodeFrom.granted, None
return '', None
class PermNode(Node):
class Meta:
proxy = True
ordering = []
# 特殊节点
UNGROUPED_NODE_KEY = 'ungrouped'
UNGROUPED_NODE_VALUE = _('Ungrouped')
FAVORITE_NODE_KEY = 'favorite'
FAVORITE_NODE_VALUE = _('Favorite')
node_from = ''
granted_assets_amount = 0
annotate_granted_node_rel_fields = {
'granted_assets_amount': F('granted_node_rels__node_assets_amount'),
'node_from': F('granted_node_rels__node_from')
}
def use_granted_assets_amount(self):
self.assets_amount = self.granted_assets_amount
@classmethod
def get_ungrouped_node(cls, assets_amount):
return cls(
id=cls.UNGROUPED_NODE_KEY,
key=cls.UNGROUPED_NODE_KEY,
value=cls.UNGROUPED_NODE_VALUE,
assets_amount=assets_amount
)
@classmethod
def get_favorite_node(cls, assets_amount):
node = cls(
id=cls.FAVORITE_NODE_KEY,
key=cls.FAVORITE_NODE_KEY,
value=cls.FAVORITE_NODE_VALUE,
)
node.assets_amount = assets_amount
return node
def get_granted_status(self, user):
status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key)
self.node_from = status
if rel_node:
self.granted_assets_amount = rel_node.node_assets_amount
return status
def save(self):
# 这是个只读 Model
raise NotImplementedError
class PermedAsset(Asset):
class Meta:
proxy = True
verbose_name = _('Permed asset')
permissions = [
('view_myassets', _('Can view my assets')),
('view_userassets', _('Can view user assets')),
('view_usergroupassets', _('Can view usergroup assets')),
]
class PermedAccount(Account):
@lazyproperty
def actions(self):
return 0
class Meta:
proxy = True
verbose_name = _('Permed account')

View File

@ -1,75 +1,64 @@
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from rest_framework.fields import empty
from django.utils.translation import ugettext_lazy as _
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from common.drf.fields import ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from assets.models import Asset, Node
from common.drf.fields import BitChoicesField, ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import ActionChoices, AssetPermission
from users.models import User, UserGroup
from perms.models import AssetPermission, Action
__all__ = ['AssetPermissionSerializer', 'ActionsField']
__all__ = ["AssetPermissionSerializer", "ActionChoicesField"]
class ActionsField(serializers.MultipleChoiceField):
class ActionChoicesField(BitChoicesField):
def __init__(self, **kwargs):
kwargs['choices'] = Action.CHOICES
super().__init__(**kwargs)
def run_validation(self, data=empty):
data = super(ActionsField, self).run_validation(data)
if isinstance(data, list):
data = Action.choices_to_value(value=data)
return data
def to_representation(self, value):
return Action.value_to_choices(value)
def to_internal_value(self, data):
if not self.allow_empty and not data:
self.fail('empty')
if not data:
return data
return Action.choices_to_value(data)
class ActionsDisplayField(ActionsField):
def to_representation(self, value):
values = super().to_representation(value)
choices = dict(Action.CHOICES)
return [choices.get(i) for i in values]
super().__init__(ActionChoices, **kwargs)
class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
users = ObjectRelatedField(queryset=User.objects, many=True, required=False)
user_groups = ObjectRelatedField(queryset=UserGroup.objects, many=True, required=False)
user_groups = ObjectRelatedField(
queryset=UserGroup.objects, many=True, required=False
)
assets = ObjectRelatedField(queryset=Asset.objects, many=True, required=False)
nodes = ObjectRelatedField(queryset=Node.objects, many=True, required=False)
actions = ActionsField(required=False, allow_null=True, label=_("Actions"))
actions = ActionChoicesField(required=False, allow_null=True, label=_("Actions"))
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
is_expired = serializers.BooleanField(read_only=True, label=_('Is expired'))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
accounts = serializers.ListField(label=_("Accounts"), required=False)
class Meta:
model = AssetPermission
fields_mini = ['id', 'name']
fields_mini = ["id", "name"]
fields_small = fields_mini + [
'accounts', 'is_active', 'is_expired', 'is_valid',
'actions', 'created_by', 'date_created', 'date_expired',
'date_start', 'comment', 'from_ticket'
"accounts",
"is_active",
"is_expired",
"is_valid",
"actions",
"created_by",
"date_created",
"date_expired",
"date_start",
"comment",
"from_ticket",
]
fields_m2m = [
'users', 'user_groups', 'assets', 'nodes',
"users",
"user_groups",
"assets",
"nodes",
]
fields = fields_small + fields_m2m
read_only_fields = ['created_by', 'date_created', 'from_ticket']
read_only_fields = ["created_by", "date_created", "from_ticket"]
extra_kwargs = {
'actions': {'label': _('Actions')},
'is_expired': {'label': _('Is expired')},
'is_valid': {'label': _('Is valid')},
"actions": {"label": _("Actions")},
"is_expired": {"label": _("Is expired")},
"is_valid": {"label": _("Is valid")},
}
def __init__(self, *args, **kwargs):
@ -77,7 +66,7 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
self.set_actions_field()
def set_actions_field(self):
actions = self.fields.get('actions')
actions = self.fields.get("actions")
if not actions:
return
choices = actions._choices
@ -86,9 +75,12 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
"""Perform necessary eager loading of data."""
queryset = queryset.prefetch_related(
'users', 'user_groups', 'assets', 'nodes',
"users",
"user_groups",
"assets",
"nodes",
)
return queryset
@ -96,35 +88,34 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
def perform_display_create(instance, **kwargs):
# 用户
users_to_set = User.objects.filter(
Q(name__in=kwargs.get('users_display')) |
Q(username__in=kwargs.get('users_display'))
Q(name__in=kwargs.get("users_display"))
| Q(username__in=kwargs.get("users_display"))
).distinct()
instance.users.add(*users_to_set)
# 用户组
user_groups_to_set = UserGroup.objects.filter(
name__in=kwargs.get('user_groups_display')
name__in=kwargs.get("user_groups_display")
).distinct()
instance.user_groups.add(*user_groups_to_set)
# 资产
assets_to_set = Asset.objects.filter(
Q(address__in=kwargs.get('assets_display')) |
Q(name__in=kwargs.get('assets_display'))
Q(address__in=kwargs.get("assets_display"))
| Q(name__in=kwargs.get("assets_display"))
).distinct()
instance.assets.add(*assets_to_set)
# 节点
nodes_to_set = Node.objects.filter(
full_value__in=kwargs.get('nodes_display')
full_value__in=kwargs.get("nodes_display")
).distinct()
instance.nodes.add(*nodes_to_set)
def create(self, validated_data):
display = {
'users_display': validated_data.pop('users_display', ''),
'user_groups_display': validated_data.pop('user_groups_display', ''),
'assets_display': validated_data.pop('assets_display', ''),
'nodes_display': validated_data.pop('nodes_display', '')
"users_display": validated_data.pop("users_display", ""),
"user_groups_display": validated_data.pop("user_groups_display", ""),
"assets_display": validated_data.pop("assets_display", ""),
"nodes_display": validated_data.pop("nodes_display", ""),
}
instance = super().create(validated_data)
self.perform_display_create(instance, **display)
return instance

View File

@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _
from common.drf.fields import ObjectRelatedField, LabeledChoiceField
from assets.models import Node, Asset, Platform, Account
from assets.const import Category, AllTypes
from perms.serializers.permission import ActionsField
from perms.serializers.permission import ActionChoicesField
__all__ = [
'NodeGrantedSerializer', 'AssetGrantedSerializer',
@ -45,7 +45,7 @@ class NodeGrantedSerializer(serializers.ModelSerializer):
class ActionsSerializer(serializers.Serializer):
actions = ActionsField(read_only=True)
actions = ActionChoicesField(read_only=True)
class AccountsGrantedSerializer(serializers.ModelSerializer):
@ -53,7 +53,7 @@ class AccountsGrantedSerializer(serializers.ModelSerializer):
# Todo: 添加前端登录逻辑中需要的一些字段,比如:是否需要手动输入密码
# need_manual = serializers.BooleanField(label=_('Need manual input'))
actions = ActionsField(read_only=True)
actions = ActionChoicesField(read_only=True)
class Meta:
model = Account

View File

@ -1,5 +1,6 @@
import time
from collections import defaultdict
from assets.models import Account
from .permission import AssetPermissionUtil
@ -8,54 +9,78 @@ __all__ = ['PermAccountUtil']
class PermAccountUtil(AssetPermissionUtil):
""" 资产授权账号相关的工具 """
@staticmethod
def get_permed_accounts_from_perms(perms, user, asset):
alias_action_bit_mapper = defaultdict(int)
alias_expired_mapper = defaultdict(list)
def get_perm_accounts_for_user(self, user, with_actions=False):
""" 获取授权给用户的所有账号 """
perms = self.get_permissions_for_user(user)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
for perm in perms:
for alias in perm.accounts:
alias_action_bit_mapper[alias] |= perm.actions
alias_expired_mapper[alias].append(perm.date_expired)
asset_accounts = asset.accounts.all()
username_account_mapper = {account.username: account for account in asset_accounts}
cleaned_accounts_action_bit = defaultdict(int)
cleaned_accounts_expired = defaultdict(list)
# @ALL 账号先处理,后面的每个最多映射一个账号
all_action_bit = alias_action_bit_mapper.pop('@ALL', None)
if all_action_bit:
for account in asset_accounts:
cleaned_accounts_action_bit[account] |= all_action_bit
cleaned_accounts_expired[account].extend(alias_expired_mapper['@ALL'])
for alias, action_bit in alias_action_bit_mapper.items():
if alias == '@USER':
if user.username in username_account_mapper:
account = username_account_mapper[user.username]
else:
account = Account.get_user_account(user.username)
elif alias == '@INPUT':
account = Account.get_manual_account()
elif alias in username_account_mapper:
account = username_account_mapper[alias]
else:
account = None
if account:
cleaned_accounts_action_bit[account] |= action_bit
cleaned_accounts_expired[account].extend(alias_expired_mapper[alias])
accounts = []
for account, action_bit in cleaned_accounts_action_bit.items():
account.actions = action_bit
account.date_expired = max(cleaned_accounts_expired[account])
accounts.append(account)
return accounts
def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False, with_perms=False):
def get_permed_accounts_for_user(self, user, asset):
""" 获取授权给用户某个资产的账号 """
perms = self.get_permissions_for_user_asset(user, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
if with_perms:
return perms, accounts
return accounts
def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False):
""" 获取授权给用户组某个资产的账号 """
perms = self.get_permissions_for_user_group_asset(user_group, asset)
accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions)
return accounts
permed_accounts = self.get_permed_accounts_from_perms(perms, user, asset)
return permed_accounts
@staticmethod
def get_perm_accounts_for_permissions(permissions, with_actions=False):
def get_accounts_for_permission(perm, with_actions=False):
""" 获取授权规则包含的账号 """
aid_actions_map = defaultdict(int)
for perm in permissions:
account_ids = perm.get_all_accounts(flat=True)
actions = perm.actions
for aid in account_ids:
aid_actions_map[str(aid)] |= actions
# 这里不行,速度太慢, 别情有很多查询
account_ids = perm.get_all_accounts(flat=True)
actions = perm.actions
for aid in account_ids:
aid_actions_map[str(aid)] |= actions
account_ids = list(aid_actions_map.keys())
accounts = Account.objects.filter(id__in=account_ids).order_by(
'asset__name', 'name', 'username'
)
if with_actions:
for account in accounts:
account.actions = aid_actions_map.get(str(account.id))
accounts = Account.objects.filter(id__in=account_ids)
return accounts
def validate_permission(self, user, asset, account_username):
""" 校验用户有某个资产下某个账号名的权限 """
perms, accounts = self.get_perm_accounts_for_user_asset(
user, asset, with_actions=True, with_perms=True
)
perm = perms.first()
actions = []
for account in accounts:
if account.username == account_username:
actions = account.actions
expire_at = perm.date_expired.timestamp() if perm else time.time()
return actions, expire_at
permed_accounts = self.get_permed_accounts_for_user(user, asset)
accounts_mapper = {account.username: account for account in permed_accounts}
account = accounts_mapper.get(account_username)
if not account:
return False, None
else:
return account.actions, account.date_expired

View File

@ -1,12 +1,6 @@
import time
from collections import defaultdict
from django.db.models import Q
from common.utils import get_logger
from perms.models import AssetPermission, Action
from perms.hands import Asset, User, UserGroup, Node
from perms.utils.user_permission import get_user_all_asset_perm_ids
from perms.models import AssetPermission
logger = get_logger(__file__)

View File

@ -19,13 +19,12 @@ from orgs.utils import (
from assets.models import (
Asset, FavoriteAsset, AssetQuerySet, NodeQuerySet
)
from orgs.models import Organization
from perms.models import (
AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation,
)
from users.models import User
from orgs.models import Organization
from perms.locks import UserGrantedTreeRebuildLock
from perms.models import (
AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation
)
NodeFrom = UserAssetGrantedTreeNodeRelation.NodeFrom
NODE_ONLY_FIELDS = ('id', 'key', 'parent_key', 'org_id')

View File

@ -7,7 +7,6 @@ from collections import defaultdict
from django.utils import timezone as dj_timezone
from django.db import migrations
from perms.models import Action
from tickets.const import TicketType
pt = re.compile(r'(\w+)\((\w+)\)')

View File

@ -1,7 +1,6 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from perms.models import Action
from .general import Ticket
__all__ = ['ApplyAssetTicket']
@ -15,15 +14,13 @@ class ApplyAssetTicket(Ticket):
# 申请信息
apply_assets = models.ManyToManyField('assets.Asset', verbose_name=_('Apply assets'))
apply_accounts = models.JSONField(default=list, verbose_name=_('Apply accounts'))
apply_actions = models.IntegerField(
choices=Action.DB_CHOICES, default=Action.ALL, verbose_name=_('Actions')
)
apply_actions = models.IntegerField(default=1, verbose_name=_('Actions'))
apply_date_start = models.DateTimeField(verbose_name=_('Date start'), null=True)
apply_date_expired = models.DateTimeField(verbose_name=_('Date expired'), null=True)
@property
def apply_actions_display(self):
return Action.value_to_choices_display(self.apply_actions)
return 'Todo'
def get_apply_actions_display(self):
return ', '.join(self.apply_actions_display)

View File

@ -1,7 +1,7 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from perms.serializers.permission import ActionsField
from perms.serializers.permission import ActionChoicesField
from perms.models import AssetPermission
from orgs.utils import tmp_to_org
from assets.models import Asset, Node
@ -16,7 +16,7 @@ asset_or_node_help_text = _("Select at least one asset or node")
class ApplyAssetSerializer(BaseApplyAssetApplicationSerializer, TicketApplySerializer):
apply_actions = ActionsField(required=True, allow_empty=False)
apply_actions = ActionChoicesField(required=True, allow_empty=False)
permission_model = AssetPermission
class Meta:

View File

@ -15,8 +15,10 @@ from ..models import User
from ..const import PasswordStrategy
__all__ = [
'UserSerializer', 'MiniUserSerializer',
'InviteSerializer', 'ServiceAccountSerializer',
"UserSerializer",
"MiniUserSerializer",
"InviteSerializer",
"ServiceAccountSerializer",
]
logger = get_logger(__file__)
@ -25,15 +27,17 @@ logger = get_logger(__file__)
class RolesSerializerMixin(serializers.Serializer):
system_roles = serializers.ManyRelatedField(
child_relation=serializers.PrimaryKeyRelatedField(queryset=Role.system_roles),
label=_('System roles'),
label=_("System roles"),
)
org_roles = serializers.ManyRelatedField(
required=False,
child_relation=serializers.PrimaryKeyRelatedField(queryset=Role.org_roles),
label=_('Org roles'),
label=_("Org roles"),
)
system_roles_display = serializers.SerializerMethodField(label=_('System roles display'))
org_roles_display = serializers.SerializerMethodField(label=_('Org roles display'))
system_roles_display = serializers.SerializerMethodField(
label=_("System roles display")
)
org_roles_display = serializers.SerializerMethodField(label=_("Org roles display"))
@staticmethod
def get_system_roles_display(user):
@ -44,20 +48,20 @@ class RolesSerializerMixin(serializers.Serializer):
return user.org_roles.display
def pop_roles_if_need(self, fields):
request = self.context.get('request')
view = self.context.get('view')
request = self.context.get("request")
view = self.context.get("view")
if not all([request, view, hasattr(view, 'action')]):
if not all([request, view, hasattr(view, "action")]):
return fields
if request.user.is_anonymous:
return fields
action = view.action or 'list'
if action in ('partial_bulk_update', 'bulk_update', 'partial_update', 'update'):
action = 'create'
action = view.action or "list"
if action in ("partial_bulk_update", "bulk_update", "partial_update", "update"):
action = "create"
model_cls_field_mapper = {
SystemRoleBinding: ['system_roles', 'system_roles_display'],
OrgRoleBinding: ['org_roles', 'system_roles_display']
SystemRoleBinding: ["system_roles", "system_roles_display"],
OrgRoleBinding: ["org_roles", "system_roles_display"],
}
for model_cls, fields_names in model_cls_field_mapper.items():
@ -75,97 +79,148 @@ class RolesSerializerMixin(serializers.Serializer):
return fields
class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializers.ModelSerializer):
class UserSerializer(
RolesSerializerMixin, CommonBulkSerializerMixin, serializers.ModelSerializer
):
password_strategy = serializers.ChoiceField(
choices=PasswordStrategy.choices, default=PasswordStrategy.email, required=False,
write_only=True, label=_('Password strategy')
choices=PasswordStrategy.choices,
default=PasswordStrategy.email,
required=False,
write_only=True,
label=_("Password strategy"),
)
mfa_enabled = serializers.BooleanField(read_only=True, label=_("MFA enabled"))
mfa_force_enabled = serializers.BooleanField(
read_only=True, label=_("MFA force enabled")
)
mfa_enabled = serializers.BooleanField(read_only=True, label=_('MFA enabled'))
mfa_force_enabled = serializers.BooleanField(read_only=True, label=_('MFA force enabled'))
mfa_level_display = serializers.ReadOnlyField(
source='get_mfa_level_display', label=_('MFA level display')
source="get_mfa_level_display", label=_("MFA level display")
)
login_blocked = serializers.BooleanField(read_only=True, label=_('Login blocked'))
is_expired = serializers.BooleanField(read_only=True, label=_('Is expired'))
login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
can_public_key_auth = serializers.ReadOnlyField(
source='can_use_ssh_key_login', label=_('Can public key authentication')
source="can_use_ssh_key_login", label=_("Can public key authentication")
)
password = EncryptedField(
label=_('Password'), required=False, allow_blank=True, allow_null=True, max_length=1024
label=_("Password"),
required=False,
allow_blank=True,
allow_null=True,
max_length=1024,
)
# Todo: 这里看看该怎么搞
# can_update = serializers.SerializerMethodField(label=_('Can update'))
# can_delete = serializers.SerializerMethodField(label=_('Can delete'))
custom_m2m_fields = {
'system_roles': [BuiltinRole.system_user],
'org_roles': [BuiltinRole.org_user]
"system_roles": [BuiltinRole.system_user],
"org_roles": [BuiltinRole.org_user],
}
class Meta:
model = User
# mini 是指能识别对象的最小单元
fields_mini = ['id', 'name', 'username']
fields_mini = ["id", "name", "username"]
# 只能写的字段, 这个虽然无法在框架上生效,但是更多对我们是提醒
fields_write_only = [
'password', 'public_key',
"password",
"public_key",
]
# small 指的是 不需要计算的直接能从一张表中获取到的数据
fields_small = fields_mini + fields_write_only + [
'email', 'wechat', 'phone', 'mfa_level', 'source', 'source_display',
'can_public_key_auth', 'need_update_password',
'mfa_enabled', 'is_service_account', 'is_valid', 'is_expired', 'is_active', # 布尔字段
'date_expired', 'date_joined', 'last_login', # 日期字段
'created_by', 'comment', # 通用字段
'is_wecom_bound', 'is_dingtalk_bound', 'is_feishu_bound', 'is_otp_secret_key_bound',
'wecom_id', 'dingtalk_id', 'feishu_id'
]
fields_small = (
fields_mini
+ fields_write_only
+ [
"email",
"wechat",
"phone",
"mfa_level",
"source",
"source_display",
"can_public_key_auth",
"need_update_password",
"mfa_enabled",
"is_service_account",
"is_valid",
"is_expired",
"is_active", # 布尔字段
"date_expired",
"date_joined",
"last_login", # 日期字段
"created_by",
"comment", # 通用字段
"is_wecom_bound",
"is_dingtalk_bound",
"is_feishu_bound",
"is_otp_secret_key_bound",
"wecom_id",
"dingtalk_id",
"feishu_id",
]
)
# 包含不太常用的字段,可以没有
fields_verbose = fields_small + [
'mfa_level_display', 'mfa_force_enabled', 'is_first_login',
'date_password_last_updated', 'avatar_url',
"mfa_level_display",
"mfa_force_enabled",
"is_first_login",
"date_password_last_updated",
"avatar_url",
]
# 外键的字段
fields_fk = []
# 多对多字段
fields_m2m = [
'groups', 'groups_display', 'system_roles', 'org_roles',
'system_roles_display', 'org_roles_display'
"groups",
"groups_display",
"system_roles",
"org_roles",
"system_roles_display",
"org_roles_display",
]
# 在serializer 上定义的字段
fields_custom = ['login_blocked', 'password_strategy']
fields_custom = ["login_blocked", "password_strategy"]
fields = fields_verbose + fields_fk + fields_m2m + fields_custom
read_only_fields = [
'date_joined', 'last_login', 'created_by', 'is_first_login',
'wecom_id', 'dingtalk_id', 'feishu_id'
"date_joined",
"last_login",
"created_by",
"is_first_login",
"wecom_id",
"dingtalk_id",
"feishu_id",
]
disallow_self_update_fields = ['is_active']
disallow_self_update_fields = ["is_active"]
extra_kwargs = {
'password': {'write_only': True, 'required': False, 'allow_null': True, 'allow_blank': True},
'public_key': {'write_only': True},
'is_first_login': {'label': _('Is first login'), 'read_only': True},
'is_active': {'label': _('Is active')},
'is_valid': {'label': _('Is valid')},
'is_service_account': {'label': _('Is service account')},
'is_expired': {'label': _('Is expired')},
'avatar_url': {'label': _('Avatar url')},
'created_by': {'read_only': True, 'allow_blank': True},
'groups_display': {'label': _('Groups name')},
'source_display': {'label': _('Source name')},
'org_role_display': {'label': _('Organization role name')},
'role_display': {'label': _('Super role name')},
'total_role_display': {'label': _('Total role name')},
'role': {'default': "User"},
'is_wecom_bound': {'label': _('Is wecom bound')},
'is_dingtalk_bound': {'label': _('Is dingtalk bound')},
'is_feishu_bound': {'label': _('Is feishu bound')},
'is_otp_secret_key_bound': {'label': _('Is OTP bound')},
'phone': {'validators': [PhoneValidator()]},
'system_role_display': {'label': _('System role name')},
"password": {
"write_only": True,
"required": False,
"allow_null": True,
"allow_blank": True,
},
"public_key": {"write_only": True},
"is_first_login": {"label": _("Is first login"), "read_only": True},
"is_active": {"label": _("Is active")},
"is_valid": {"label": _("Is valid")},
"is_service_account": {"label": _("Is service account")},
"is_expired": {"label": _("Is expired")},
"avatar_url": {"label": _("Avatar url")},
"created_by": {"read_only": True, "allow_blank": True},
"groups_display": {"label": _("Groups name")},
"source_display": {"label": _("Source name")},
"org_role_display": {"label": _("Organization role name")},
"role_display": {"label": _("Super role name")},
"total_role_display": {"label": _("Total role name")},
"role": {"default": "User"},
"is_wecom_bound": {"label": _("Is wecom bound")},
"is_dingtalk_bound": {"label": _("Is dingtalk bound")},
"is_feishu_bound": {"label": _("Is feishu bound")},
"is_otp_secret_key_bound": {"label": _("Is OTP bound")},
"phone": {"validators": [PhoneValidator()]},
"system_role_display": {"label": _("System role name")},
}
def validate_password(self, password):
password_strategy = self.initial_data.get('password_strategy')
password_strategy = self.initial_data.get("password_strategy")
if self.instance is None and password_strategy != PasswordStrategy.custom:
# 创建用户,使用邮件设置密码
return
@ -176,32 +231,34 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
@staticmethod
def change_password_to_raw(attrs):
password = attrs.pop('password', None)
password = attrs.pop("password", None)
if password:
attrs['password_raw'] = password
attrs["password_raw"] = password
return attrs
@staticmethod
def clean_auth_fields(attrs):
for field in ('password', 'public_key'):
for field in ("password", "public_key"):
value = attrs.get(field)
if not value:
attrs.pop(field, None)
return attrs
def check_disallow_self_update_fields(self, attrs):
request = self.context.get('request')
request = self.context.get("request")
if not request or not request.user.is_authenticated:
return attrs
if not self.instance:
return attrs
if request.user.id != self.instance.id:
return attrs
disallow_fields = set(list(attrs.keys())) & set(self.Meta.disallow_self_update_fields)
disallow_fields = set(list(attrs.keys())) & set(
self.Meta.disallow_self_update_fields
)
if not disallow_fields:
return attrs
# 用户自己不能更新自己的一些字段
logger.debug('Disallow update self fields: %s', disallow_fields)
logger.debug("Disallow update self fields: %s", disallow_fields)
for field in disallow_fields:
attrs.pop(field, None)
return attrs
@ -210,7 +267,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
attrs = self.check_disallow_self_update_fields(attrs)
attrs = self.change_password_to_raw(attrs)
attrs = self.clean_auth_fields(attrs)
attrs.pop('password_strategy', None)
attrs.pop("password_strategy", None)
return attrs
def save_and_set_custom_m2m_fields(self, validated_data, save_handler, created):
@ -219,8 +276,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
roles = validated_data.pop(f, None)
if created and not roles:
roles = [
Role.objects.filter(id=role.id).first()
for role in default_roles
Role.objects.filter(id=role.id).first() for role in default_roles
]
m2m_values[f] = roles
@ -234,22 +290,26 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer
def update(self, instance, validated_data):
save_handler = partial(super().update, instance)
instance = self.save_and_set_custom_m2m_fields(validated_data, save_handler, created=False)
instance = self.save_and_set_custom_m2m_fields(
validated_data, save_handler, created=False
)
return instance
def create(self, validated_data):
save_handler = super().create
instance = self.save_and_set_custom_m2m_fields(validated_data, save_handler, created=True)
instance = self.save_and_set_custom_m2m_fields(
validated_data, save_handler, created=True
)
return instance
class UserRetrieveSerializer(UserSerializer):
login_confirm_settings = serializers.PrimaryKeyRelatedField(
read_only=True, source='login_confirm_setting.reviewers', many=True
read_only=True, source="login_confirm_setting.reviewers", many=True
)
class Meta(UserSerializer.Meta):
fields = UserSerializer.Meta.fields + ['login_confirm_settings']
fields = UserSerializer.Meta.fields + ["login_confirm_settings"]
class MiniUserSerializer(serializers.ModelSerializer):
@ -260,8 +320,10 @@ class MiniUserSerializer(serializers.ModelSerializer):
class InviteSerializer(RolesSerializerMixin, serializers.Serializer):
users = serializers.PrimaryKeyRelatedField(
queryset=User.get_nature_users(), many=True, label=_('Select users'),
help_text=_('For security, only list several users')
queryset=User.get_nature_users(),
many=True,
label=_("Select users"),
help_text=_("For security, only list several users"),
)
system_roles = None
system_roles_display = None
@ -271,22 +333,23 @@ class InviteSerializer(RolesSerializerMixin, serializers.Serializer):
class ServiceAccountSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['id', 'name', 'access_key', 'comment']
read_only_fields = ['access_key']
fields = ["id", "name", "access_key", "comment"]
read_only_fields = ["access_key"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from authentication.serializers import AccessKeySerializer
self.fields['access_key'] = AccessKeySerializer(read_only=True)
self.fields["access_key"] = AccessKeySerializer(read_only=True)
def get_username(self):
return self.initial_data.get('name')
return self.initial_data.get("name")
def get_email(self):
name = self.initial_data.get('name')
name = self.initial_data.get("name")
name_max_length = 128 - len(User.service_account_email_suffix)
name = pretty_string(name, max_length=name_max_length, ellipsis_str='-')
return '{}{}'.format(name, User.service_account_email_suffix)
name = pretty_string(name, max_length=name_max_length, ellipsis_str="-")
return "{}{}".format(name, User.service_account_email_suffix)
def validate_name(self, name):
email = self.get_email()
@ -296,12 +359,12 @@ class ServiceAccountSerializer(serializers.ModelSerializer):
else:
users = User.objects.all()
if users.filter(email=email) or users.filter(username=username):
raise serializers.ValidationError(_('name not unique'), code='unique')
raise serializers.ValidationError(_("name not unique"), code="unique")
return name
def create(self, validated_data):
name = validated_data['name']
name = validated_data["name"]
email = self.get_email()
comment = validated_data.get('comment', '')
comment = validated_data.get("comment", "")
user, ak = User.create_service_account(name, email, comment)
return user