mirror of https://github.com/jumpserver/jumpserver
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.1 KiB
93 lines
3.1 KiB
# -*- coding: utf-8 -*- |
|
# |
|
from rest_framework import generics |
|
from rest_framework.serializers import Serializer |
|
|
|
from common.permissions import IsValidUser |
|
from common.utils import get_logger |
|
from .. import serializers |
|
from ..models import Preference |
|
|
|
logger = get_logger(__file__) |
|
|
|
|
|
class PreferenceApi(generics.RetrieveUpdateAPIView): |
|
permission_classes = (IsValidUser,) |
|
queryset = Preference.objects.all() |
|
serializer_class_mapper = { |
|
'lina': serializers.LinaSerializer, |
|
'luna': serializers.LunaSerializer, |
|
'koko': serializers.KokoSerializer, |
|
} |
|
|
|
def check_permissions(self, request): |
|
if self.category not in self.serializer_class_mapper: |
|
return self.permission_denied(request, 'category is invalid') |
|
return super().check_permissions(request) |
|
|
|
@property |
|
def user(self): |
|
return self.request.user |
|
|
|
@property |
|
def category(self): |
|
return self.request.query_params.get('category') |
|
|
|
def get_serializer_class(self): |
|
cls = self.serializer_class_mapper.get(self.category) |
|
return cls |
|
|
|
def get_field_defaults(self, serializer): |
|
field_defaults = {} |
|
fields = serializer.get_fields() |
|
for name, field in fields.items(): |
|
if isinstance(field, Serializer): |
|
field_defaults[name] = self.get_field_defaults(field) |
|
continue |
|
field_defaults[name] = getattr(field, 'default', None) |
|
return field_defaults |
|
|
|
def get_encrypted_fields(self, serializer): |
|
encrypted_fields = [] |
|
fields = serializer.get_fields() |
|
for name, field in fields.items(): |
|
if isinstance(field, Serializer): |
|
encrypted_fields += self.get_encrypted_fields(field) |
|
continue |
|
if not field.write_only: |
|
continue |
|
encrypted_fields.append(name) |
|
return encrypted_fields |
|
|
|
def get_object(self): |
|
serializer = self.get_serializer_class()() |
|
field_defaults = self.get_field_defaults(serializer) |
|
|
|
qs = self.queryset.filter(user=self.user, category=self.category) |
|
if not qs.exists(): |
|
return field_defaults |
|
|
|
data = dict(qs.values_list('name', 'value')) |
|
for k, v in data.items(): |
|
for d in field_defaults.values(): |
|
if k in d: |
|
d[k] = v |
|
break |
|
return field_defaults |
|
|
|
def perform_update(self, serializer): |
|
user = self.user |
|
category = self.category |
|
model = self.queryset.model |
|
encrypted_fields = self.get_encrypted_fields(serializer) |
|
data = serializer.validated_data |
|
for d in data.values(): |
|
for name, value in d.items(): |
|
kwargs = {'name': name, 'user': user} |
|
defaults = {'category': category} |
|
if name in encrypted_fields: |
|
value = model.encrypt(value) |
|
defaults['encrypted'] = True |
|
defaults['value'] = value |
|
defaults.update(kwargs) |
|
model.objects.update_or_create(defaults, **kwargs)
|
|
|