diff --git a/apps/accounts/const/account.py b/apps/accounts/const/account.py index 109044934..b86e9400b 100644 --- a/apps/accounts/const/account.py +++ b/apps/accounts/const/account.py @@ -18,3 +18,9 @@ class AliasAccount(TextChoices): class Source(TextChoices): LOCAL = 'local', _('Local') COLLECTED = 'collected', _('Collected') + + +class BulkCreateStrategy(TextChoices): + SKIP = 'skip', _('Skip') + UPDATE = 'update', _('Update') + ERROR = 'error', _('Failed') diff --git a/apps/accounts/serializers/account/account.py b/apps/accounts/serializers/account/account.py index 739ef7e37..a2ab1e22e 100644 --- a/apps/accounts/serializers/account/account.py +++ b/apps/accounts/serializers/account/account.py @@ -1,10 +1,11 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers from rest_framework.validators import ( - UniqueTogetherValidator, ValidationError + UniqueTogetherValidator ) -from accounts.const import SecretType, Source +from accounts import validator +from accounts.const import SecretType, Source, BulkCreateStrategy from accounts.models import Account, AccountTemplate from accounts.tasks import push_accounts_to_assets_task from assets.const import Category, AllTypes @@ -17,15 +18,6 @@ from .base import BaseAccountSerializer logger = get_logger(__name__) -class SkipUniqueValidator(UniqueTogetherValidator): - def __call__(self, attrs, serializer): - try: - super().__call__(attrs, serializer) - except ValidationError as e: - logger.debug(f'{attrs.get("asset")}: {e.detail[0]}') - raise ValidationError({}) - - class AccountSerializerCreateValidateMixin: from_id: str template: bool @@ -113,12 +105,16 @@ class AccountSerializer(AccountSerializerCreateMixin, BaseAccountSerializer): required=False, queryset=Account.objects, allow_null=True, allow_empty=True, label=_('Su from'), attrs=('id', 'name', 'username') ) + strategy = LabeledChoiceField( + choices=BulkCreateStrategy.choices, default=BulkCreateStrategy.SKIP, + write_only=True, label=_('Account policy') + ) class Meta(BaseAccountSerializer.Meta): model = Account fields = BaseAccountSerializer.Meta.fields + [ 'su_from', 'asset', 'template', 'version', - 'push_now', 'source', 'connectivity', + 'push_now', 'source', 'connectivity', 'strategy' ] extra_kwargs = { **BaseAccountSerializer.Meta.extra_kwargs, @@ -138,17 +134,27 @@ class AccountSerializer(AccountSerializerCreateMixin, BaseAccountSerializer): return queryset def get_validators(self): - validators = [] - data = self.context['request'].data - action = self.context['view'].action + ignore = False + validators = [validator.AccountSecretTypeValidator(fields=('secret_type',))] + view = self.context.get('view') + request = self.context.get('request') + if request and view: + data = request.data + action = view.action + ignore = action == 'create' and isinstance(data, list) + _validators = super().get_validators() - ignore = action == 'create' and isinstance(data, list) and len(data) > 1 for v in _validators: if ignore and isinstance(v, UniqueTogetherValidator): - v = SkipUniqueValidator(v.queryset, v.fields) + v = validator.AccountUniqueTogetherValidator(v.queryset, v.fields) validators.append(v) return validators + def validate(self, attrs): + attrs = super().validate(attrs) + attrs.pop('strategy', None) + return attrs + class AccountSecretSerializer(SecretReadableMixin, AccountSerializer): class Meta(AccountSerializer.Meta): diff --git a/apps/accounts/validator.py b/apps/accounts/validator.py new file mode 100644 index 000000000..ccf1e8c11 --- /dev/null +++ b/apps/accounts/validator.py @@ -0,0 +1,98 @@ +from functools import reduce + +from django.utils.translation import ugettext_lazy as _ +from rest_framework.validators import ( + UniqueTogetherValidator, ValidationError +) + +from accounts.const import BulkCreateStrategy +from accounts.models import Account +from assets.const import Protocol + +__all__ = ['AccountUniqueTogetherValidator', 'AccountSecretTypeValidator'] + + +class ValidatorStrategyMixin: + + @staticmethod + def get_strategy(attrs): + return attrs.get('strategy', BulkCreateStrategy.SKIP) + + def __call__(self, attrs, serializer): + message = None + try: + super().__call__(attrs, serializer) + except ValidationError as e: + message = e.detail[0] + strategy = self.get_strategy(attrs) + if not message: + return + if strategy == BulkCreateStrategy.ERROR: + raise ValidationError(message, code='error') + elif strategy in [BulkCreateStrategy.SKIP, BulkCreateStrategy.UPDATE]: + raise ValidationError({}) + else: + return + + +class SecretTypeValidator: + requires_context = True + protocol_settings = Protocol.settings() + message = _('{field_name} not a legal option') + + def __init__(self, fields): + self.fields = fields + + def __call__(self, attrs, serializer): + secret_types = set() + asset = attrs['asset'] + secret_type = attrs['secret_type'] + platform_protocols_dict = { + name: self.protocol_settings.get(name, {}).get('secret_types', []) + for name in asset.platform.protocols.values_list('name', flat=True) + } + + for name in asset.protocols.values_list('name', flat=True): + if name in platform_protocols_dict: + secret_types |= set(platform_protocols_dict[name]) + if secret_type not in secret_types: + message = self.message.format(field_name=secret_type) + raise ValidationError(message, code='error') + + +class UpdateAccountMixin: + fields: tuple + get_strategy: callable + + def update(self, attrs): + unique_together = Account._meta.unique_together + unique_together_fields = reduce(lambda x, y: set(x) | set(y), unique_together) + query = {field_name: attrs[field_name] for field_name in unique_together_fields} + account = Account.objects.filter(**query).first() + if not account: + query = {field_name: attrs[field_name] for field_name in self.fields} + account = Account.objects.filter(**query).first() + + for k, v in attrs.items(): + setattr(account, k, v) + account.save() + + def __call__(self, attrs, serializer): + try: + super().__call__(attrs, serializer) + except ValidationError as e: + strategy = self.get_strategy(attrs) + if strategy == BulkCreateStrategy.UPDATE: + self.update(attrs) + message = e.detail[0] + raise ValidationError(message, code='unique') + + +class AccountUniqueTogetherValidator( + ValidatorStrategyMixin, UpdateAccountMixin, UniqueTogetherValidator +): + pass + + +class AccountSecretTypeValidator(ValidatorStrategyMixin, SecretTypeValidator): + pass