jumpserver/apps/accounts/validator.py

99 lines
3.1 KiB
Python
Raw Normal View History

from functools import reduce
from django.utils.translation import ugettext_lazy as _
from rest_framework.validators import (
UniqueTogetherValidator, ValidationError
)
from accounts.const import BulkCreateStrategy
from accounts.models import Account
from assets.const import Protocol
# Public validators exported by this module.
__all__ = [
    'AccountUniqueTogetherValidator',
    'AccountSecretTypeValidator',
]
class ValidatorStrategyMixin:
    """Map a wrapped validator's failure onto the bulk-create strategy.

    Place this mixin before the real validator in the bases so that
    ``super().__call__`` reaches it. When the wrapped validator raises
    ``ValidationError``, the outcome depends on ``attrs['strategy']``
    (``BulkCreateStrategy.SKIP`` when absent):

    * ``ERROR`` -- raise ``ValidationError(<first message>, code='error')``;
    * ``SKIP`` or ``UPDATE`` -- raise an empty ``ValidationError({})``;
    * any other value -- return normally, so validation passes.
    """

    @staticmethod
    def get_strategy(attrs):
        """Return ``attrs['strategy']``, defaulting to ``BulkCreateStrategy.SKIP``."""
        return attrs.get('strategy', BulkCreateStrategy.SKIP)

    @staticmethod
    def _first_error_message(detail):
        """Return the first entry of a non-empty list ``detail``; otherwise ``detail``."""
        # DRF wraps a plain-string message into a one-item list, but a
        # validator may also raise with a dict detail (field -> message), for
        # which a bare ``detail[0]`` lookup raises KeyError.
        if isinstance(detail, (list, tuple)) and detail:
            return detail[0]
        return detail

    def __call__(self, attrs, serializer):
        """Run the wrapped validator and apply the strategy to any failure.

        :raises ValidationError: as described in the class docstring.
        """
        try:
            super().__call__(attrs, serializer)
        except ValidationError as e:
            message = self._first_error_message(e.detail)
        else:
            # The wrapped validator accepted ``attrs``.
            return

        strategy = self.get_strategy(attrs)
        if strategy == BulkCreateStrategy.ERROR:
            raise ValidationError(message, code='error')
        if strategy in (BulkCreateStrategy.SKIP, BulkCreateStrategy.UPDATE):
            # Fail validation without attaching any message.
            raise ValidationError({})
        # Any other strategy: tolerate the failure and let validation pass.
class SecretTypeValidator:
    """Reject a ``secret_type`` that the asset's protocols do not allow.

    The allowed set is the union of the ``secret_types`` entries from
    ``Protocol.settings()`` over every protocol name present both on the asset
    and on the asset's platform.
    """
    requires_context = True
    # Computed once, when the class body executes at import time.
    protocol_settings = Protocol.settings()
    message = _('{field_name} not a legal option')

    def __init__(self, fields):
        # Stored for interface compatibility; ``__call__`` does not read it.
        self.fields = fields

    @staticmethod
    def _get_value(attrs, instance, name):
        """Return ``attrs[name]`` if present, else ``instance.<name>``, else None."""
        if name in attrs:
            return attrs[name]
        return getattr(instance, name, None)

    def __call__(self, attrs, serializer):
        """Validate the secret type against the asset's allowed secret types.

        On a partial update ``asset`` or ``secret_type`` may be absent from
        ``attrs``; the value then falls back to the same attribute of
        ``serializer.instance`` instead of failing with KeyError. When no
        asset or secret type can be determined there is nothing to check.

        :raises ValidationError: with ``code='error'`` when the secret type
            is not allowed.
        """
        instance = getattr(serializer, 'instance', None)
        asset = self._get_value(attrs, instance, 'asset')
        secret_type = self._get_value(attrs, instance, 'secret_type')
        if asset is None or secret_type is None:
            return

        # Allowed secret types per protocol name defined on the platform.
        platform_secret_types = {
            name: (self.protocol_settings.get(name) or {}).get('secret_types') or []
            for name in asset.platform.protocols.values_list('name', flat=True)
        }
        # Only protocols that are on both the asset and its platform count.
        secret_types = set()
        for name in asset.protocols.values_list('name', flat=True):
            secret_types.update(platform_secret_types.get(name, ()))

        if secret_type not in secret_types:
            message = self.message.format(field_name=secret_type)
            raise ValidationError(message, code='error')
class UpdateAccountMixin:
    """On ``BulkCreateStrategy.UPDATE``, update the account a uniqueness check rejected.

    Meant to sit before ``UniqueTogetherValidator`` in the bases: it reads
    ``fields`` from that validator and ``get_strategy`` from a sibling mixin.
    """
    fields: tuple
    get_strategy: callable

    def update(self, attrs):
        """Copy ``attrs`` onto the existing ``Account`` that clashes with them.

        The account is looked up first by the union of all fields in
        ``Account._meta.unique_together`` (only when ``attrs`` supplies every
        one of them), then by this validator's own ``fields``. If no account
        is found, nothing is updated.
        """
        # The initializer keeps an empty ``unique_together`` from raising TypeError.
        unique_together_fields = reduce(
            lambda acc, group: acc | set(group),
            Account._meta.unique_together,
            set(),
        )
        account = None
        # Guard both cases: an empty field set would match every row, and a
        # field missing from ``attrs`` would raise KeyError.
        if unique_together_fields and unique_together_fields.issubset(attrs):
            query = {name: attrs[name] for name in unique_together_fields}
            account = Account.objects.filter(**query).first()
        if not account:
            query = {name: attrs[name] for name in self.fields}
            account = Account.objects.filter(**query).first()
        if account is None:
            # No existing account to update; avoid setattr() on None.
            return
        for key, value in attrs.items():
            setattr(account, key, value)
        account.save()

    def __call__(self, attrs, serializer):
        """Run the wrapped uniqueness check; on conflict, optionally update.

        :raises ValidationError: the wrapped validator's error, re-raised with
            its first message and ``code='unique'``. An error whose detail is
            not a list is re-raised unchanged, without attempting an update.
        """
        try:
            super().__call__(attrs, serializer)
        except ValidationError as e:
            detail = e.detail
            # DRF wraps a string message in a list. Any other shape (e.g. a
            # dict naming missing required fields) is not a clash with an
            # existing row, so propagate it rather than updating with
            # incomplete attrs or failing on ``detail[0]``.
            if not isinstance(detail, list):
                raise
            if self.get_strategy(attrs) == BulkCreateStrategy.UPDATE:
                self.update(attrs)
            message = detail[0] if detail else detail
            raise ValidationError(message, code='unique')
class AccountUniqueTogetherValidator(
    ValidatorStrategyMixin,
    UpdateAccountMixin,
    UniqueTogetherValidator,
):
    """``UniqueTogetherValidator`` combined with the update and strategy mixins.

    ``__call__`` resolves through ``ValidatorStrategyMixin``, then
    ``UpdateAccountMixin``, then ``UniqueTogetherValidator``.
    """
class AccountSecretTypeValidator(
    ValidatorStrategyMixin,
    SecretTypeValidator,
):
    """``SecretTypeValidator`` whose failures are handled by ``ValidatorStrategyMixin``."""